Skip to content

Commit ebe3b5e

Browse files
authored
Merge pull request #11853 from sneaxiy/complete_py_reader_python
Add Python Reader Op (Python side and unittests)
2 parents a0530c3 + 0fef252 commit ebe3b5e

File tree

9 files changed

+439
-18
lines changed

9 files changed

+439
-18
lines changed

paddle/fluid/framework/reader.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ enum ReaderStatus { kRunning, kStopped };
2929

3030
class ReaderBase {
3131
public:
32-
void ReadNext(std::vector<LoDTensor>* out);
32+
virtual void ReadNext(std::vector<LoDTensor>* out);
3333

34-
void Shutdown();
34+
virtual void Shutdown();
3535

36-
void Start();
36+
virtual void Start();
3737

3838
// Return the readers which are at the end of the decorating chain. Basically
3939
// they are the readers just before the read op.
@@ -42,7 +42,7 @@ class ReaderBase {
4242
virtual ~ReaderBase();
4343

4444
protected:
45-
virtual void ReadNextImpl(std::vector<LoDTensor>* out) = 0;
45+
virtual void ReadNextImpl(std::vector<LoDTensor>* out) {}
4646

4747
virtual void ShutdownImpl() {}
4848

paddle/fluid/operators/reader/blocking_queue.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,15 @@ class BlockingQueue {
8181
}
8282
}
8383

84+
void ReOpen() {
85+
std::lock_guard<std::mutex> lock(mutex_);
86+
closed_ = false;
87+
std::deque<T> new_deque;
88+
queue_.swap(new_deque);
89+
send_cv_.notify_all();
90+
receive_cv_.notify_all();
91+
}
92+
8493
void Close() {
8594
std::lock_guard<std::mutex> lock(mutex_);
8695
closed_ = true;

paddle/fluid/operators/reader/create_py_reader_op.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,17 @@ class PyReader : public framework::FileReader {
2727
queue_ = queue;
2828
}
2929

30-
void ReadNextImpl(std::vector<framework::LoDTensor>* out) override {
30+
void ReadNext(std::vector<framework::LoDTensor>* out) override {
3131
bool success;
3232
*out = queue_->Pop(&success);
3333
if (!success) out->clear();
3434
}
3535

36-
private:
37-
void ShutdownImpl() override { /* TODO */
38-
}
36+
void Shutdown() override { queue_->Close(); }
3937

40-
void StartImpl() override { /* TODO */
41-
}
38+
void Start() override { queue_->ReOpen(); }
4239

40+
private:
4341
std::shared_ptr<LoDTensorBlockingQueue> queue_;
4442
};
4543

paddle/fluid/operators/reader/lod_tensor_blocking_queue.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,15 @@ class LoDTensorBlockingQueue {
5858

5959
inline size_t Size() const { return queue_.Size(); }
6060

61-
inline void Close() { return queue_.Close(); }
61+
inline void ReOpen() { queue_.ReOpen(); }
62+
63+
inline void Close() { queue_.Close(); }
6264

6365
inline bool IsClosed() const { return queue_.IsClosed(); }
6466

6567
private:
66-
void CheckDims(const std::vector<framework::LoDTensor>& lod_tensor_vec) {
68+
void CheckDims(
69+
const std::vector<framework::LoDTensor>& lod_tensor_vec) const {
6770
PADDLE_ENFORCE(dims_.size() == lod_tensor_vec.size(),
6871
"Expect input size is %d but found %s", dims_.size(),
6972
lod_tensor_vec.size());

paddle/fluid/pybind/pybind.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414
#include <Python.h>
1515
#include <algorithm>
1616
#include <map>
17+
#include <memory>
1718
#include <mutex> // NOLINT // for call_once
1819
#include <string>
1920
#include <unordered_map>
@@ -310,7 +311,8 @@ All parameter, weight, gradient are variables in Paddle.
310311
::paddle::operators::reader::LoDTensorBlockingQueue;
311312
using LoDTensorBlockingQueueHolder =
312313
::paddle::operators::reader::LoDTensorBlockingQueueHolder;
313-
py::class_<LoDTensorBlockingQueue>(m, "LoDTensorBlockingQueue", "")
314+
py::class_<LoDTensorBlockingQueue, std::shared_ptr<LoDTensorBlockingQueue>>(
315+
m, "LoDTensorBlockingQueue", "")
314316
.def("push",
315317
[](LoDTensorBlockingQueue &self,
316318
const std::vector<framework::LoDTensor> &lod_tensor_vec) {
@@ -325,17 +327,17 @@ All parameter, weight, gradient are variables in Paddle.
325327
m.def("init_lod_tensor_blocking_queue",
326328
[](Variable &var, size_t capacity,
327329
const std::vector<std::vector<int64_t>> &shapes)
328-
-> LoDTensorBlockingQueue * {
330+
-> std::shared_ptr<LoDTensorBlockingQueue> {
329331
std::vector<DDim> dims(shapes.size());
330332
std::transform(shapes.begin(), shapes.end(), dims.begin(),
331333
[](const std::vector<int64_t> &shape) {
332334
return make_ddim(shape);
333335
});
334336
auto *holder = var.GetMutable<LoDTensorBlockingQueueHolder>();
335337
holder->InitOnce(capacity, dims);
336-
return holder->GetQueue().get();
338+
return holder->GetQueue();
337339
},
338-
py::return_value_policy::reference);
340+
py::return_value_policy::copy);
339341

340342
py::class_<Scope>(m, "Scope", "")
341343
.def("var",
@@ -543,6 +545,8 @@ All parameter, weight, gradient are variables in Paddle.
543545
});
544546

545547
py::class_<LoDTensorArray>(m, "LoDTensorArray")
548+
.def("__init__",
549+
[](LoDTensorArray &instance) { new (&instance) LoDTensorArray(); })
546550
.def("__getitem__",
547551
[](LoDTensorArray &self, size_t i) { return &self.at(i); },
548552
py::return_value_policy::reference)

python/paddle/fluid/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
import transpiler
4545
from param_attr import ParamAttr, WeightNormParamAttr
4646
from data_feeder import DataFeeder
47-
from core import LoDTensor, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
47+
from core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope
4848
from transpiler import DistributeTranspiler, InferenceTranspiler, \
4949
memory_optimize, release_memory
5050
from concurrency import (Go, make_channel, channel_send, channel_recv,
@@ -72,6 +72,7 @@
7272
'backward',
7373
'regularizer',
7474
'LoDTensor',
75+
'LoDTensorArray',
7576
'CPUPlace',
7677
'CUDAPlace',
7778
'CUDAPinnedPlace',

python/paddle/fluid/layers/io.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
__all__ = [
2525
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
2626
'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
27-
'double_buffer', 'random_data_generator', 'Preprocessor', 'load'
27+
'double_buffer', 'random_data_generator', 'py_reader', 'Preprocessor',
28+
'load'
2829
]
2930

3031

@@ -445,6 +446,88 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
445446
return monkey_patch_reader_methods(main_prog_var)
446447

447448

449+
def py_reader(capacity, shapes, dtypes, lod_levels=None):
450+
"""
451+
Create a reader and blocking queue for data feeding in Python
452+
453+
This layer returns a Reader Variable and a BlockingQueue.
454+
The BlockingQueue provides a `push()` method to push a `LoDTensorArray`
455+
object into the queue on the Python side. On the C++ side, the Reader
456+
Variable invokes the `pop()` method of the queue to retrieve the
457+
fed data. Feeding data on the Python side and fetching
458+
data on the C++ side can run in parallel. The BlockingQueue should be closed
459+
using the `close()` method when it is no longer needed.
460+
461+
Args:
462+
capacity(int): The maximum capacity of the BlockingQueue.
463+
shapes(list): List of tuples (or lists) which declare the data shapes.
464+
dtypes(list): List of strs which declare the data types.
465+
lod_levels(list): List of ints which declare the data lod_levels. If None, every lod_level defaults to 0.
466+
467+
Returns:
468+
tuple(Variable, BlockingQueue):
469+
A Reader Variable from which we can get feeding data.
470+
471+
A BlockingQueue object for data feeding.
472+
473+
Examples:
474+
475+
.. code-block:: python
476+
477+
reader, queue = fluid.layers.py_reader(
478+
capacity=10,
479+
shapes=[[-1,3,224,224], [-1,1]],
480+
dtypes=['float32', 'int64'])
481+
# Via the reader, we can use 'read_file' layer to get data:
482+
image, label = fluid.layers.read_file(reader)
483+
484+
# Via the blocking queue, we can feed data using threads
485+
def feed_data(queue, feed_images, feed_labels):
486+
for feed_image, feed_label in zip(feed_images, feed_labels):
487+
data = core.LoDTensorArray()
488+
data.append(feed_image)
489+
data.append(feed_label)
490+
queue.push(data)
491+
492+
thread = threading.Thread(target=feed_data, args=(queue, feed_images, feed_labels))
493+
thread.start()
494+
"""
495+
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
496+
shape_concat = []
497+
ranks = []
498+
499+
for shape in shapes:
500+
shape_concat.extend(shape)
501+
ranks.append(len(shape))
502+
503+
if lod_levels is None:
504+
lod_levels = [0] * len(shapes)
505+
506+
queue_name = unique_name('lod_tensor_blocking_queue')
507+
var = global_scope().var(queue_name)
508+
feed_queue = core.init_lod_tensor_blocking_queue(var, capacity, shapes)
509+
510+
startup_blk = default_startup_program().current_block()
511+
startup_var = startup_blk.create_var(name=unique_name('create_py_reader'))
512+
startup_blk.append_op(
513+
type='create_py_reader',
514+
inputs={'blocking_queue': queue_name},
515+
outputs={'Out': [startup_var]},
516+
attrs={
517+
'shape_concat': shape_concat,
518+
'lod_levels': lod_levels,
519+
'ranks': ranks
520+
})
521+
522+
startup_var.desc.set_dtypes(dtypes)
523+
startup_var.persistable = True
524+
525+
main_prog_var = _copy_reader_var_(default_main_program().current_block(),
526+
startup_var)
527+
528+
return monkey_patch_reader_methods(main_prog_var), feed_queue
529+
530+
448531
def open_files(filenames,
449532
shapes,
450533
lod_levels,
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import paddle.fluid as fluid
17+
import numpy as np
18+
from threading import Thread
19+
20+
21+
def feed_data(feed_queue, inputs):
22+
for in_data in inputs:
23+
feed_queue.push(in_data)
24+
25+
26+
class TestPyReader(unittest.TestCase):
27+
def setUp(self):
28+
self.capacity = 10
29+
self.batch_size_min = 10
30+
self.batch_size_max = 20
31+
self.shapes = [(-1, 3, 2, 1), (-1, 1)]
32+
self.lod_levels = [0, 0]
33+
self.dtypes = ['float32', 'int64']
34+
self.iterations = 20
35+
36+
def test_single_thread_main(self):
37+
self.main(use_thread=False)
38+
39+
def test_multiple_thread_main(self):
40+
self.main(use_thread=True)
41+
42+
def main(self, use_thread=False):
43+
with fluid.program_guard(fluid.Program(), fluid.Program()):
44+
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
45+
) else fluid.CPUPlace()
46+
executor = fluid.Executor(place)
47+
48+
data_file, feed_queue = fluid.layers.py_reader(
49+
capacity=self.capacity,
50+
dtypes=self.dtypes,
51+
lod_levels=self.lod_levels,
52+
shapes=self.shapes)
53+
54+
read_out_data = fluid.layers.read_file(data_file)
55+
self.inputs = []
56+
57+
for i in range(self.iterations):
58+
in_data = fluid.LoDTensorArray()
59+
batch_size = np.random.random_integers(self.batch_size_min,
60+
self.batch_size_max)
61+
for shape, dtype in zip(self.shapes, self.dtypes):
62+
next_data = np.random.uniform(
63+
low=0, high=1000,
64+
size=(batch_size, ) + shape[1:]).astype(dtype)
65+
in_data.append(executor.as_lodtensor(next_data))
66+
67+
self.inputs.append(in_data)
68+
69+
executor.run(fluid.default_startup_program())
70+
self.outputs = []
71+
if use_thread:
72+
thread = Thread(
73+
target=feed_data, args=(feed_queue, self.inputs))
74+
thread.start()
75+
for in_data in self.inputs:
76+
self.outputs.append(
77+
executor.run(fetch_list=list(read_out_data)))
78+
else:
79+
for in_data in self.inputs:
80+
feed_queue.push(in_data)
81+
self.outputs.append(
82+
executor.run(fetch_list=list(read_out_data)))
83+
84+
feed_queue.close()
85+
self.validate()
86+
87+
def validate(self):
88+
self.assertEqual(len(self.inputs), len(self.outputs))
89+
for in_data_list, out_data_list in zip(self.inputs, self.outputs):
90+
self.assertEqual(len(in_data_list), len(out_data_list))
91+
in_data_list_np = [
92+
np.array(in_lod_tensor) for in_lod_tensor in in_data_list
93+
]
94+
for in_data, out_data in zip(in_data_list_np, out_data_list):
95+
self.assertTrue((in_data == out_data).all())
96+
97+
98+
if __name__ == '__main__':
99+
unittest.main()

0 commit comments

Comments
 (0)