Skip to content

Commit b681537

Browse files
authored
Add multiprocess reader (#13311)
* add multiprocess_reader * add multiprocess_reader to reader decorator * support piped multi process reader * revert v2 decorator * add comment to multiprocess_reader * optimize code * use ujson to speed up json serialize/deserialize * add assert to multiprocess_reader * update comment of multiprocess_reader * optimize ujson import, handle error case * optimize import ujson * remove ujson from requirements.txt * add import sys to decorator.py
1 parent b4dd5c2 commit b681537

File tree

2 files changed

+127
-1
lines changed

2 files changed

+127
-1
lines changed

python/paddle/reader/decorator.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414

1515
__all__ = [
1616
'map_readers', 'buffered', 'compose', 'chain', 'shuffle',
17-
'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader'
17+
'ComposeNotAligned', 'firstn', 'xmap_readers', 'PipeReader',
18+
'multiprocess_reader'
1819
]
1920

2021
from threading import Thread
2122
import subprocess
23+
import multiprocessing
24+
import sys
2225

2326
from six.moves.queue import Queue
2427
from six.moves import zip_longest
@@ -332,6 +335,100 @@ def xreader():
332335
return xreader
333336

334337

338+
def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
    """
    multiprocess_reader uses Python multiprocessing to read data from the
    given readers in parallel and merges the samples through either a
    multiprocessing.Queue or multiprocessing.Pipe. One worker process is
    started per input reader.

    Note: multiprocessing.Queue requires read/write access to /dev/shm,
    which some platforms do not provide.

    You need to create the readers first; they should be independent of
    each other so that each process can work on its own.

    Args:
        readers (list): non-empty list of zero-argument callables, each
            returning an iterable of samples. A sample must never be None
            (None is the internal end-of-stream marker).
        use_pipe (bool): if True, samples are JSON-serialized and passed
            through pipes; otherwise a shared bounded queue is used.
        queue_size (int): capacity of the shared queue (queue mode only).

    Returns:
        A zero-argument generator function yielding the merged samples.

    Raises:
        ValueError: if ``readers`` is not a non-empty list, or a reader
            yields None.

    An example:

    .. code-block:: python

        reader0 = reader(["file01", "file02"])
        reader1 = reader(["file11", "file12"])
        reader2 = reader(["file21", "file22"])
        reader = multiprocess_reader([reader0, reader1, reader2],
            queue_size=100, use_pipe=False)
    """

    # ujson is a faster drop-in replacement for the stdlib json; fall back
    # silently (with a stderr note) when it is not installed.
    try:
        import ujson as json
    except Exception as e:
        sys.stderr.write("import ujson error: " + str(e) + " use json\n")
        import json

    # Explicit validation instead of `assert`: asserts are stripped when
    # Python runs with -O, which would let bad input through.
    if not isinstance(readers, list) or not readers:
        raise ValueError("readers must be a non-empty list")

    def _read_into_queue(reader, queue):
        # Worker body (queue mode): forward every sample, then put the
        # None end-of-stream marker so the consumer can count finishers.
        for sample in reader():
            if sample is None:
                raise ValueError("sample has None")
            queue.put(sample)
        queue.put(None)

    def queue_reader():
        queue = multiprocessing.Queue(queue_size)
        for reader in readers:
            p = multiprocessing.Process(
                target=_read_into_queue, args=(reader, queue))
            p.start()

        # Each worker sends exactly one None when it finishes; stop once
        # all of them have reported.
        reader_num = len(readers)
        finish_num = 0
        while finish_num < reader_num:
            sample = queue.get()
            if sample is None:
                finish_num += 1
            else:
                yield sample

    def _read_into_pipe(reader, conn):
        # Worker body (pipe mode): samples are JSON-encoded strings;
        # json.dumps(None) ("null") serves as the end-of-stream marker.
        for sample in reader():
            if sample is None:
                raise ValueError("sample has None!")
            conn.send(json.dumps(sample))
        conn.send(json.dumps(None))
        conn.close()

    def pipe_reader():
        conns = []
        for reader in readers:
            parent_conn, child_conn = multiprocessing.Pipe()
            conns.append(parent_conn)
            p = multiprocessing.Process(
                target=_read_into_pipe, args=(reader, child_conn))
            p.start()

        reader_num = len(readers)
        finish_num = 0
        conn_to_remove = []
        while finish_num < reader_num:
            # Drop connections whose worker finished on the previous pass;
            # removing while iterating `conns` would skip entries.
            for conn in conn_to_remove:
                conns.remove(conn)
            conn_to_remove = []
            for conn in conns:
                sample = json.loads(conn.recv())
                if sample is None:
                    finish_num += 1
                    conn.close()
                    conn_to_remove.append(conn)
                else:
                    yield sample

    if use_pipe:
        return pipe_reader
    else:
        return queue_reader
430+
431+
335432
def _buf2lines(buf, line_break="\n"):
336433
# FIXME: line_break should be automatically configured.
337434
lines = buf.split(line_break)

python/paddle/reader/tests/decorator_test.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import time
1616
import unittest
17+
import functools
1718

1819
import paddle.reader
1920

@@ -174,5 +175,33 @@ def example_reader(myfiles):
174175
temp.close()
175176

176177

178+
class TestMultiProcessReader(unittest.TestCase):
    """Check that multiprocess_reader yields the union of all readers'
    samples, for both the queue and the pipe merge strategies."""

    def setUp(self):
        # unittest invokes setUp (camelCase) automatically before each test
        # method; a lowercase `setup` would only run if called by hand.
        # Samples are dealt round-robin to three partial readers.
        self.samples = []
        for i in range(1000):
            self.samples.append([[i], [i + 1, i + 2], i + 3])

        def reader(index):
            for i in range(len(self.samples)):
                if i % 3 == index:
                    yield self.samples[i]

        self.reader0 = functools.partial(reader, 0)
        self.reader1 = functools.partial(reader, 1)
        self.reader2 = functools.partial(reader, 2)

    def reader_test(self, use_pipe):
        # Pass use_pipe/queue_size as keywords: the signature is
        # (readers, use_pipe=True, queue_size=1000), so the positional call
        # (readers, 100, use_pipe) would bind use_pipe=100 (truthy) and
        # silently exercise only the pipe code path in both sub-tests.
        results = []
        for data in paddle.reader.multiprocess_reader(
                [self.reader0, self.reader1, self.reader2],
                use_pipe=use_pipe,
                queue_size=100)():
            results.append(data)
        self.assertEqual(sorted(self.samples), sorted(results))

    def test_multi_process_reader(self):
        self.reader_test(use_pipe=False)
        self.reader_test(use_pipe=True)
177206
if __name__ == '__main__':
178207
unittest.main()

0 commit comments

Comments
 (0)