Skip to content

Commit f9556dc

Browse files
committed
use open_files reader to read multiple files
1 parent a6a7b6f commit f9556dc

File tree

5 files changed

+37
-33
lines changed

5 files changed

+37
-33
lines changed

doc/v2/howto/recordio/README.md renamed to doc/fluid/howto/cluster/fluid_recordio.md

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,14 @@ The above codes would generate multiple RecordIO files on your host like:
8989

9090
```bash
9191
.
92-
\_mnist.recordio-00000
93-
|-mnist.recordio-00001
94-
|-mnist.recordio-00002
95-
|-mnist.recordio-00003
96-
|-mnist.recordio-00004
92+
\_mnist-00000.recordio
93+
|-mnist-00001.recordio
94+
|-mnist-00002.recordio
95+
|-mnist-00003.recordio
96+
|-mnist-00004.recordio
9797
```
9898

99-
1. read these RecordIO files with `fluid.layers.io.open_recordio_file`
99+
1. open multiple RecordIO files with `fluid.layers.io.open_files`
100100

101101
For a distributed training job, the distributed operator system will schedule trainer processes on multiple nodes;
102102
each trainer process reads part of the whole training data, so we usually take the following approach to make the training
@@ -113,10 +113,12 @@ def gen_train_list(file_pattern, trainers, trainer_id):
113113

114114
trainers = int(os.getenv("TRAINERS"))
115115
trainer_id = int(os.getenv("PADDLE_INIT_TRAINER_ID"))
116-
data_file = fluid.layers.io.open_recordio_file(
117-
filename=gen_train_list("./mnist.recordio*", trainers, trainer_id),
118-
shapes=[(-1, 784),(-1, 1)],
119-
lod_levels=[0, 0],
120-
dtypes=["float32", "int32"])
121-
data_file = fluid.layers.io.batch(data_file, batch_size=4)
116+
data_file = fluid.layers.io.open_files(
117+
filenames=gen_train_list("./mnist-[0-9]*.recordio", 2, 0),
118+
thread_num=1,
119+
shapes=[(-1, 784),(-1, 1)],
120+
lod_levels=[0, 0],
121+
dtypes=["float32", "int32"])
122+
img, label = fluid.layers.io.read_file(data_file)
123+
...
122124
```

paddle/fluid/operators/reader/create_recordio_file_reader_op.cc

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,20 @@ class CreateRecordIOReaderOp : public framework::OperatorBase {
6565
static_cast<int>(shape_concat.size()),
6666
"The accumulate of all ranks should be equal to the "
6767
"shape concat's length.");
68-
auto filenames = Attr<std::vector<std::string>>("filenames");
68+
std::string filename = Attr<std::string>("filename");
6969

7070
auto* out = scope.FindVar(Output("Out"))
7171
->template GetMutable<framework::ReaderHolder>();
72-
for (auto& fn : filenames) {
73-
out->Reset(
74-
new RecordIOFileReader<true>(fn, RestoreShapes(shape_concat, ranks)));
75-
}
72+
73+
out->Reset(new RecordIOFileReader<true>(
74+
filename, RestoreShapes(shape_concat, ranks)));
7675
}
7776
};
7877

7978
class CreateRecordIOReaderOpMaker : public FileReaderMakerBase {
8079
protected:
8180
void Apply() override {
82-
AddAttr<std::vector<std::string>>("filenames",
83-
"The filenames of record io reader");
81+
AddAttr<std::string>("filename", "The filename of record io reader");
8482
AddComment(R"DOC(
8583
CreateRecordIOReader Operator
8684

python/paddle/fluid/layers/io.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..executor import global_scope
2222

2323
__all__ = [
24-
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_files',
24+
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
2525
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
2626
'random_data_generator', 'Preprocessor'
2727
]
@@ -291,20 +291,20 @@ def _copy_reader_create_op_(block, op):
291291
return new_op
292292

293293

294-
def open_recordio_files(filenames,
295-
shapes,
296-
lod_levels,
297-
dtypes,
298-
pass_num=1,
299-
for_parallel=True):
294+
def open_recordio_file(filename,
295+
shapes,
296+
lod_levels,
297+
dtypes,
298+
pass_num=1,
299+
for_parallel=True):
300300
"""
301301
Open a RecordIO file
302302
303303
This layer takes a RecordIO file to read from and returns a Reader Variable.
304304
Via the Reader Variable, we can get data from the given RecordIO file.
305305
306306
Args:
307-
filename(str) or list(str): The RecordIO file's name.
307+
filename(str): The RecordIO file's name.
308308
shapes(list): List of tuples which declaring data shapes.
309309
lod_levels(list): List of ints which declaring data lod_level.
310310
dtypes(list): List of strs which declaring data type.
@@ -336,8 +336,6 @@ def open_recordio_files(filenames,
336336
ranks.append(len(shape))
337337

338338
var_name = unique_name('open_recordio_file')
339-
if isinstance(filenames, str):
340-
filenames = [filenames]
341339

342340
startup_blk = default_startup_program().current_block()
343341
startup_var = startup_blk.create_var(name=var_name)
@@ -347,7 +345,7 @@ def open_recordio_files(filenames,
347345
attrs={
348346
'shape_concat': shape_concat,
349347
'lod_levels': lod_levels,
350-
'filenames': filenames,
348+
'filename': filename,
351349
'ranks': ranks
352350
})
353351

python/paddle/fluid/recordio_writer.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
import core
1617
import contextlib
17-
__all__ = ['convert_reader_to_recordio_file']
18+
__all__ = [
19+
'convert_reader_to_recordio_file', 'convert_reader_to_recordio_files'
20+
]
1821

1922

2023
@contextlib.contextmanager
@@ -48,7 +51,7 @@ def convert_reader_to_recordio_file(
4851

4952

5053
def convert_reader_to_recordio_files(
51-
filename_suffix,
54+
filename,
5255
batch_per_file,
5356
reader_creator,
5457
feeder,
@@ -57,13 +60,16 @@ def convert_reader_to_recordio_files(
5760
feed_order=None):
5861
if feed_order is None:
5962
feed_order = feeder.feed_names
63+
f_name, f_ext = os.path.splitext(filename)
64+
assert (f_ext == ".recordio")
65+
6066
lines = []
6167
f_idx = 0
6268
counter = 0
6369
for idx, batch in enumerate(reader_creator()):
6470
lines.append(batch)
6571
if idx >= batch_per_file and idx % batch_per_file == 0:
66-
filename = "%s-%05d" % (filename_suffix, f_idx)
72+
filename = "%s-%05d%s" % (f_name, f_idx, f_ext)
6773
with create_recordio_writer(filename, compressor,
6874
max_num_records) as writer:
6975
for l in lines:

tools/codestyle/docstring_checker.pyc

11.5 KB
Binary file not shown.

0 commit comments

Comments
 (0)