Skip to content

Commit c204f0c

Browse files
committed
Refine PyReader
1 parent 6a46c07 commit c204f0c

File tree

2 files changed

+82
-46
lines changed

2 files changed

+82
-46
lines changed

python/paddle/fluid/layers/io.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
# limitations under the License.
1414
import contextlib
1515
import multiprocessing
16+
import threading
1617

18+
from ..data_feeder import DataFeeder
1719
from control_flow import BlockGuard
1820
from layer_function_generator import templatedoc
1921
from .. import core
2022
from ..executor import global_scope
2123
from ..framework import convert_np_dtype_to_dtype_, default_main_program, \
22-
default_startup_program
24+
default_startup_program, program_guard, Program
2325
from ..layer_helper import LayerHelper
2426
from ..unique_name import generate as unique_name
2527

@@ -550,7 +552,71 @@ def feed_data(queue, feed_images, feed_labels):
550552
# py_reader.
551553
double_buffer_reader.reset = reader.reset
552554
reader = double_buffer_reader
553-
return reader, feed_queue
555+
556+
# monkey patch py_reader special methods
557+
reader.queue = feed_queue
558+
current_reset_method = reader.reset
559+
reader.thread = None
560+
reader.tensor_provider = None
561+
562+
def start_provide_thread(func):
    """Start a background thread that pushes batches from ``func`` into
    ``feed_queue``.

    Args:
        func: A zero-argument callable returning an iterable. Every element
            it yields is itself an iterable of items; each item that is not
            already a ``core.LoDTensor`` is wrapped into a new
            ``core.LoDTensor`` set on ``core.CPUPlace()``.

    The thread object is stored in ``reader.thread`` and started before this
    function returns. Nothing is returned.
    """

    def __provider_thread__():
        try:
            for tensors in func():
                array = core.LoDTensorArray()
                for item in tensors:
                    if not isinstance(item, core.LoDTensor):
                        # Wrap raw data in a CPU LoDTensor before queuing it.
                        tmp = core.LoDTensor()
                        tmp.set(item, core.CPUPlace())
                        item = tmp

                    array.append(item)

                feed_queue.push(array)
        finally:
            # Close the queue on every exit path. Previously an exception
            # raised by ``func`` skipped the close, leaving the queue open
            # with no producer behind it. The exception itself still
            # propagates out of the thread unchanged.
            feed_queue.close()

    reader.thread = threading.Thread(target=__provider_thread__)
    reader.thread.start()
579+
580+
def __set_tensor_provider__(func):
    """Register ``func`` as the reader's tensor provider and start feeding.

    Args:
        func: A zero-argument callable handed to ``start_provide_thread``,
            which iterates over it in a background thread.
    """
    # The reader is initialised with ``reader.tensor_provider = None`` and
    # the reset hook checks ``reader.tensor_provider`` before joining and
    # restarting the provider thread. Storing the callable only under
    # ``_tensor_provider`` left that attribute at None, so a reset never
    # restarted the feed. Store it under the name that is read back.
    reader.tensor_provider = func
    # Keep the previously written private attribute as an alias.
    reader._tensor_provider = func
    start_provide_thread(func)
583+
584+
def __set_paddle_reader__(reader):
# Adapt a paddle-style reader so it can feed this py_reader.
#
# The `reader` argument is passed to DataFeeder.decorate_reader and the
# result is iterated by a tensor provider, which is then registered via
# __set_tensor_provider__.
#
# NOTE: the parameter `reader` shadows the enclosing py_reader object of the
# same name inside this function; the rebinding below affects the local only.
# NOTE(review): `dtypes`, `shapes` and `lod_levels` come from the enclosing
# scope, which is not visible here -- presumably the py_reader arguments;
# confirm.
585+
# Build the feed variables under a fresh pair of Programs.
# NOTE(review): presumably this keeps the temporary `data` layers out of the
# caller's main/startup programs -- confirm.
with program_guard(Program(), Program()):
586+
feed_list = []
587+
counter = 0
588+
for dtype, shape, lod_level in zip(dtypes, shapes, lod_levels):
589+
# Feed variables are named by position: "0", "1", ...
name = str(counter)
590+
feed_list.append(
591+
data(
592+
name=name,
593+
dtype=dtype,
594+
shape=shape,
595+
lod_level=lod_level))
596+
counter += 1
597+
598+
feeder = DataFeeder(feed_list=feed_list, place=core.CPUPlace())
599+
600+
reader = feeder.decorate_reader(reader, multi_devices=False)
601+
602+
def __tensor_provider__():
603+
# `data` here is a local loop variable; it shadows the module-level `data`
# layer function only inside this generator.
for data in reader():
604+
# Emit the slots in positional order, looked up by their "0", "1", ... names.
yield [data[str(idx)] for idx in xrange(counter)]
605+
606+
__set_tensor_provider__(__tensor_provider__)
607+
608+
def __reset__():
    """Reset the underlying reader, then restart the provider thread.

    The reset method captured in ``current_reset_method`` is always called
    first. If a provider thread exists and a tensor provider is registered,
    the thread is joined and a new one is started for the same provider.
    """
    current_reset_method()

    # Nothing to restart unless a thread exists and a provider is registered.
    if reader.thread is None or reader.tensor_provider is None:
        return

    provider_thread = reader.thread
    provider_thread.join()
    # Restart the provider thread.
    start_provide_thread(reader.tensor_provider)
614+
615+
reader.reset = __reset__
616+
reader.decorate_tensor_provider = __set_tensor_provider__
617+
reader.decorate_paddle_reader = __set_paddle_reader__
618+
619+
return reader
554620

555621

556622
def open_files(filenames,

python/paddle/fluid/tests/demo/pyreader.py

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import paddle.fluid as fluid
16-
import paddle.dataset.mnist as mnist
15+
import numpy
16+
1717
import paddle
18+
import paddle.dataset.mnist as mnist
19+
import paddle.fluid as fluid
1820
import paddle.v2
19-
import threading
20-
import numpy
2121

2222

2323
def network(is_train):
24-
reader, queue = fluid.layers.py_reader(
24+
reader = fluid.layers.py_reader(
2525
capacity=10,
2626
shapes=((-1, 784), (-1, 1)),
2727
dtypes=('float32', 'int64'),
@@ -37,32 +37,7 @@ def network(is_train):
3737

3838
prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
3939
loss = fluid.layers.cross_entropy(input=prediction, label=label)
40-
return fluid.layers.mean(loss), queue, reader
41-
42-
43-
def pipe_reader_to_queue(reader_creator, queue):
44-
with fluid.program_guard(fluid.Program(), fluid.Program()):
45-
feeder = fluid.DataFeeder(
46-
feed_list=[
47-
fluid.layers.data(
48-
name='img', dtype='float32', shape=[784]),
49-
fluid.layers.data(
50-
name='label', dtype='int64', shape=[1])
51-
],
52-
place=fluid.CPUPlace())
53-
54-
def __thread_main__():
55-
for data in feeder.decorate_reader(
56-
reader_creator, multi_devices=False)():
57-
tmp = fluid.core.LoDTensorArray()
58-
tmp.append(data['img'])
59-
tmp.append(data['label'])
60-
queue.push(tmp)
61-
queue.close()
62-
63-
th = threading.Thread(target=__thread_main__)
64-
th.start()
65-
return th
40+
return fluid.layers.mean(loss), reader
6641

6742

6843
def main():
@@ -71,15 +46,15 @@ def main():
7146

7247
with fluid.program_guard(train_prog, startup_prog):
7348
with fluid.unique_name.guard():
74-
loss, train_queue, train_reader = network(True)
49+
loss, train_reader = network(True)
7550
adam = fluid.optimizer.Adam(learning_rate=0.01)
7651
adam.minimize(loss)
7752

7853
test_prog = fluid.Program()
7954
test_startup = fluid.Program()
8055
with fluid.program_guard(test_prog, test_startup):
8156
with fluid.unique_name.guard():
82-
test_loss, test_queue, test_reader = network(False)
57+
test_loss, test_reader = network(False)
8358

8459
fluid.Executor(fluid.CUDAPlace(0)).run(startup_prog)
8560
fluid.Executor(fluid.CUDAPlace(0)).run(test_startup)
@@ -90,21 +65,21 @@ def main():
9065
tester = fluid.ParallelExecutor(
9166
use_cuda=True, share_vars_from=trainer, main_program=test_prog)
9267

68+
train_reader.decorate_paddle_reader(
69+
paddle.v2.reader.shuffle(
70+
paddle.batch(mnist.train(), 256), buf_size=8192))
71+
72+
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 256))
73+
9374
for epoch_id in xrange(10):
94-
train_data_thread = pipe_reader_to_queue(
95-
paddle.batch(paddle.v2.reader.firstn(mnist.train(), 32), 64),
96-
train_queue)
9775
try:
9876
while True:
9977
print 'train_loss', numpy.array(
10078
trainer.run(fetch_list=[loss.name]))
10179
except fluid.core.EOFException:
10280
print 'End of epoch', epoch_id
10381
train_reader.reset()
104-
train_data_thread.join()
10582

106-
test_data_thread = pipe_reader_to_queue(
107-
paddle.batch(mnist.test(), 32), test_queue)
10883
try:
10984
while True:
11085
print 'test loss', numpy.array(
@@ -113,11 +88,6 @@ def main():
11388
print 'End of testing'
11489
test_reader.reset()
11590

116-
test_data_thread.join()
117-
break
118-
del trainer
119-
del tester
120-
12191

12292
if __name__ == '__main__':
12393
main()

0 commit comments

Comments
 (0)