Skip to content

Commit 9ff5184

Browse files
authored
Merge pull request #13732 from jacquesqiao/add-fake-reader
add a fake reader for speed test
2 parents 943e4de + 91e8299 commit 9ff5184

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

python/paddle/reader/decorator.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
__all__ = [
1616
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
1717
'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader',
18-
'multiprocess_reader'
18+
'multiprocess_reader', 'Fake'
1919
]
2020

2121
from threading import Thread
@@ -504,3 +504,39 @@ def get_line(self, cut_lines=True, line_break="\n"):
504504
yield decomp_buff
505505
else:
506506
break
507+
508+
509+
class Fake(object):
510+
"""
511+
fake reader will cache the first data it read and yield it out for data_num times.
512+
It is used to cache a data from real reader and use it for speed testing.
513+
514+
:param reader: the origin reader
515+
:param data_num: times that this reader will yield data.
516+
517+
:return: a fake reader.
518+
519+
Examples:
520+
.. code-block:: python
521+
522+
def reader():
523+
for i in range(10):
524+
yield i
525+
526+
fake_reader = Fake()(reader, 100)
527+
"""
528+
529+
def __init__(self):
530+
self.data = None
531+
self.yield_num = 0
532+
533+
def __call__(self, reader, data_num):
534+
def fake_reader():
535+
if self.data is None:
536+
self.data = next(reader())
537+
while self.yield_num < data_num:
538+
yield self.data
539+
self.yield_num += 1
540+
self.yield_num = 0
541+
542+
return fake_reader

python/paddle/reader/tests/decorator_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,5 +203,21 @@ def test_multi_process_reader(self):
203203
self.reader_test(use_pipe=True)
204204

205205

206+
class TestFakeReader(unittest.TestCase):
207+
def test_fake_reader(self):
208+
def reader():
209+
for i in range(10):
210+
yield i
211+
212+
data_num = 100
213+
fake_reader = paddle.reader.Fake()(reader, data_num)
214+
for _ in range(10):
215+
i = 0
216+
for data in fake_reader():
217+
self.assertEqual(data, 0)
218+
i += 1
219+
self.assertEqual(i, data_num)
220+
221+
206222
if __name__ == '__main__':
207223
unittest.main()

0 commit comments

Comments
 (0)