Skip to content

Commit 61b3a59

Browse files
committed
Refine Python Reader
1 parent c204f0c commit 61b3a59

File tree

3 files changed

+23
-11
lines changed

3 files changed

+23
-11
lines changed

paddle/fluid/operators/reader/create_py_reader_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ class PyReader : public framework::FileReader {
3333
if (!success) out->clear();
3434
}
3535

36+
~PyReader() { queue_->Close(); }
37+
3638
void Shutdown() override { queue_->Close(); }
3739

3840
void Start() override { queue_->ReOpen(); }

python/paddle/fluid/layers/io.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,7 @@ def feed_data(queue, feed_images, feed_labels):
558558
current_reset_method = reader.reset
559559
reader.thread = None
560560
reader.tensor_provider = None
561+
reader.exited = False
561562

562563
def start_provide_thread(func):
563564
def __provider_thread__():
@@ -571,17 +572,20 @@ def __provider_thread__():
571572

572573
array.append(item)
573574

575+
if reader.exited:
576+
break
574577
feed_queue.push(array)
578+
if reader.exited:
579+
break
575580
feed_queue.close()
576581

577582
reader.thread = threading.Thread(target=__provider_thread__)
578583
reader.thread.start()
579584

580585
def __set_tensor_provider__(func):
581-
reader._tensor_provider = func
582-
start_provide_thread(reader._tensor_provider)
586+
reader.tensor_provider = func
583587

584-
def __set_paddle_reader__(reader):
588+
def __set_paddle_reader__(paddle_reader):
585589
with program_guard(Program(), Program()):
586590
feed_list = []
587591
counter = 0
@@ -596,25 +600,29 @@ def __set_paddle_reader__(reader):
596600
counter += 1
597601

598602
feeder = DataFeeder(feed_list=feed_list, place=core.CPUPlace())
599-
600-
reader = feeder.decorate_reader(reader, multi_devices=False)
603+
paddle_reader = feeder.decorate_reader(
604+
paddle_reader, multi_devices=False)
601605

602606
def __tensor_provider__():
603-
for data in reader():
604-
yield [data[str(idx)] for idx in xrange(counter)]
607+
for slots in paddle_reader():
608+
yield [slots[str(idx)] for idx in xrange(counter)]
605609

606610
__set_tensor_provider__(__tensor_provider__)
607611

608612
def __reset__():
609613
current_reset_method()
610614
if reader.thread is not None and reader.tensor_provider is not None:
615+
reader.exited = True
611616
reader.thread.join()
612-
# restart provider thread.
613-
start_provide_thread(reader.tensor_provider)
617+
reader.exited = False
618+
619+
def __start__():
620+
start_provide_thread(reader.tensor_provider)
614621

615622
reader.reset = __reset__
616623
reader.decorate_tensor_provider = __set_tensor_provider__
617624
reader.decorate_paddle_reader = __set_paddle_reader__
625+
reader.start = __start__
618626

619627
return reader
620628

python/paddle/fluid/tests/demo/pyreader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,12 @@ def main():
6767

6868
train_reader.decorate_paddle_reader(
6969
paddle.v2.reader.shuffle(
70-
paddle.batch(mnist.train(), 256), buf_size=8192))
70+
paddle.batch(mnist.train(), 512), buf_size=8192))
7171

72-
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 256))
72+
test_reader.decorate_paddle_reader(paddle.batch(mnist.test(), 512))
7373

7474
for epoch_id in xrange(10):
75+
train_reader.start()
7576
try:
7677
while True:
7778
print 'train_loss', numpy.array(
@@ -80,6 +81,7 @@ def main():
8081
print 'End of epoch', epoch_id
8182
train_reader.reset()
8283

84+
test_reader.start()
8385
try:
8486
while True:
8587
print 'test loss', numpy.array(

0 commit comments

Comments
 (0)