Skip to content

Commit 530556d

Browse files
authored
Merge pull request #10864 from JiayiFeng/dev_expose_random_gen
expose random_data_generator
2 parents 728621e + c2436f2 commit 530556d

File tree

3 files changed

+84
-16
lines changed

3 files changed

+84
-16
lines changed

paddle/fluid/operators/reader/create_random_data_generator_op.cc

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,15 @@ namespace reader {
2121
template <typename T>
2222
class RandomDataGenerator : public framework::ReaderBase {
2323
public:
24-
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float min,
25-
float max)
26-
: framework::ReaderBase(), min_(min), max_(max), shapes_(shapes) {
27-
PADDLE_ENFORCE_LE(
28-
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
24+
RandomDataGenerator(const std::vector<framework::DDim>& shapes, float low,
25+
float high)
26+
: framework::ReaderBase(), low_(low), high_(high), shapes_(shapes) {
27+
PADDLE_ENFORCE_LE(low, high,
28+
"'low' shouldn't be greater than 'high'.(%f vs %f)", low,
29+
high);
2930
unsigned int seed = std::random_device()();
3031
engine_.seed(seed);
31-
dist_ = std::uniform_real_distribution<float>(min_, max_);
32+
dist_ = std::uniform_real_distribution<float>(low_, high_);
3233
}
3334

3435
void ReadNext(std::vector<framework::LoDTensor>* out) override {
@@ -53,8 +54,8 @@ class RandomDataGenerator : public framework::ReaderBase {
5354
void ReInit() override { return; }
5455

5556
private:
56-
float min_;
57-
float max_;
57+
float low_;
58+
float high_;
5859
std::minstd_rand engine_;
5960
std::uniform_real_distribution<float> dist_;
6061
std::vector<framework::DDim> shapes_;
@@ -78,22 +79,22 @@ class CreateRandomDataGeneratorOp : public framework::OperatorBase {
7879
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
7980
auto* out = scope.FindVar(Output("Out"))
8081
->template GetMutable<framework::ReaderHolder>();
81-
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("min"),
82-
Attr<float>("max")));
82+
out->Reset(new RandomDataGenerator<T>(shapes, Attr<float>("low"),
83+
Attr<float>("high")));
8384
}
8485
};
8586

8687
class CreateRandomDataGeneratorOpMaker : public FileReaderMakerBase {
8788
protected:
8889
void Apply() override {
89-
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
90-
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
90+
AddAttr<float>("low", "The lower bound of reader's uniform distribution.");
91+
AddAttr<float>("high", "The upper bound of reader's uniform distribution.");
9192
AddComment(R"DOC(
9293
CreateRandomDataGenerator Operator
9394
9495
This Op creates a random reader.
9596
The reader generates random data instead of really reading from files.
96-
Generated data follow an uniform distribution between 'min' and 'max'.
97+
Generated data follow an uniform distribution between 'low' and 'high'.
9798
)DOC");
9899
}
99100
};

python/paddle/fluid/layers/io.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def open_recordio_file(filename,
321321
dtypes=['float32', 'int64'])
322322
323323
# Via the reader, we can use 'read_file' layer to get data:
324-
image, label = fluid.layers.read_file(reader)
324+
image, label = fluid.layers.io.read_file(reader)
325325
"""
326326
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
327327
shape_concat = []
@@ -359,6 +359,73 @@ def open_recordio_file(filename,
359359
return monkey_patch_reader_methods(main_prog_var)
360360

361361

362+
def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
363+
"""
364+
Create a uniform random data generator
365+
366+
This layer returns a Reader Variable.
367+
Instead of opening a file and reading data from it, this
368+
Reader Variable generates float uniform random data by itself.
369+
It can be used as a dummy reader to test a network without
370+
opening a real file.
371+
372+
Args:
373+
low(float): The lower bound of data's uniform distribution.
374+
high(float): The upper bound of data's uniform distribution.
375+
shapes(list): List of tuples declaring the shape of each data field.
376+
lod_levels(list): List of ints declaring the lod_level of each data field.
377+
for_parallel(bool): Set it to True if you are going to run
378+
subsequent operators in parallel.
379+
380+
Returns:
381+
Variable: A Reader Variable from which we can get random data.
382+
383+
Examples:
384+
.. code-block:: python
385+
386+
reader = fluid.layers.io.random_data_generator(
387+
low=0.0,
388+
high=1.0,
389+
shapes=[(3,224,224), (1,)],
390+
lod_levels=[0, 0])
391+
392+
# Via the reader, we can use 'read_file' layer to get data:
393+
image, label = fluid.layers.io.read_file(reader)
394+
"""
395+
dtypes = [core.VarDesc.VarType.FP32] * len(shapes)
396+
shape_concat = []
397+
ranks = []
398+
399+
for shape in shapes:
400+
shape_concat.extend(shape)
401+
ranks.append(len(shape))
402+
403+
var_name = unique_name('random_data_generator')
404+
405+
startup_blk = default_startup_program().current_block()
406+
startup_var = startup_blk.create_var(name=var_name)
407+
startup_blk.append_op(
408+
type='create_random_data_generator',
409+
outputs={'Out': [startup_var]},
410+
attrs={
411+
'low': low,
412+
'high': high,
413+
'shape_concat': shape_concat,
414+
'lod_levels': lod_levels,
415+
'ranks': ranks
416+
})
417+
418+
startup_var.desc.set_dtypes(dtypes)
419+
startup_var.persistable = True
420+
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
421+
startup_var)
422+
423+
if for_parallel:
424+
main_prog_var = parallel(reader=main_prog_var)
425+
426+
return monkey_patch_reader_methods(main_prog_var)
427+
428+
362429
def open_files(filenames,
363430
shapes,
364431
lod_levels,

python/paddle/fluid/tests/test_cpp_reader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@
4444
attrs={
4545
"shape_concat": [1, 2, 1, 1],
4646
"ranks": [2, 2],
47-
"min": 0.0,
48-
"max": 1.0,
47+
"low": 0.0,
48+
"high": 1.0,
4949
'lod_levels': [0, 0]
5050
})
5151

0 commit comments

Comments
 (0)