Skip to content

Commit 5d96b6e

Browse files
authored
Add Queue.get delay for multiprocess data loader (#22604) (#22640)
1 parent 750c6f4 commit 5d96b6e

File tree

2 files changed

+69
-18
lines changed

2 files changed

+69
-18
lines changed

python/paddle/fluid/reader.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@
3434
import Queue as queue
3535
else:
3636
import queue
37-
# NOTE: [ avoid hanging ] This value is used in getting data from another process
38-
MP_CHECK_TIMEOUT = 10
37+
# NOTE: [ avoid hanging ] These value is used in getting data from another process
38+
QUEUE_GET_TIMEOUT = 5
39+
MAX_GET_FAILED_TIME = 12
3940

4041
__all__ = ['PyReader', 'DataLoader']
4142

@@ -485,6 +486,17 @@ def __handler__(signum, frame):
485486

486487
signal.signal(signal.SIGCHLD, __handler__)
487488

489+
def _exit_thread_expectedly(self):
490+
self._thread_done_event.set()
491+
self._blocking_queue.close()
492+
self._data_queue.close()
493+
494+
def _exit_thread_unexpectedly(self):
495+
self._thread_done_event.set()
496+
self._blocking_queue.kill()
497+
self._data_queue.close()
498+
logging.error("DataLoader reader thread raised an exception!")
499+
488500
def _reader_process_loop(self):
489501
try:
490502
# set signal handler
@@ -506,17 +518,29 @@ def _reader_process_loop(self):
506518
six.reraise(*sys.exc_info())
507519

508520
def _reader_thread_loop_with_process(self):
521+
get_sample_try_time = 0
509522
while not self._thread_done_event.is_set():
510523
try:
511524
# NOTE: [ avoid hanging ] Even with carefully designed data dependencies
512525
# (i.e., a put() always corresponding to a get()), hanging on get() can
513526
# still happen when data in queue is corrupted (e.g., due to
514527
# Queue.cancel_join_thread or unexpected exit). So we set a timeout whenever
515528
# we try to get data from `data_queue`
516-
sample = self._data_queue.get(timeout=MP_CHECK_TIMEOUT)
529+
sample = self._data_queue.get(timeout=QUEUE_GET_TIMEOUT)
530+
get_sample_try_time = 0
517531
except queue.Empty:
518-
self._thread_done_event.set()
519-
logging.error("The reader has not read data for a long time.")
532+
get_sample_try_time += 1
533+
if get_sample_try_time > MAX_GET_FAILED_TIME:
534+
self._exit_thread_unexpectedly()
535+
raise RuntimeError(
536+
"DataLoader reader thread has not read data for a long time (60s)."
537+
)
538+
else:
539+
# NOTE: [ avoid failed quickly ] Sometimes if the reader child process has a heavy burden,
540+
# the child process has no enough time to put the data in the queue when the main process
541+
# start trying to get data from queue. At this time, failure to read data should not be
542+
# counted as a fatal error, there should be a certain number of attempts.
543+
continue
520544

521545
if not self._thread_done_event.is_set():
522546
if sample is not None:
@@ -532,20 +556,10 @@ def _reader_thread_loop_with_process(self):
532556
if not self._blocking_queue.push(array):
533557
self._blocking_queue.close()
534558
except:
535-
self._thread_done_event.set()
536-
self._blocking_queue.kill()
537-
self._data_queue.close()
538-
logging.warning(
539-
"DygraphDataLoader reader thread raised an exception."
540-
)
559+
self._exit_thread_unexpectedly()
541560
six.reraise(*sys.exc_info())
542561
else:
543-
self._thread_done_event.set()
544-
self._blocking_queue.close()
545-
self._data_queue.close()
546-
else:
547-
self._blocking_queue.kill()
548-
self._data_queue.close()
562+
self._exit_thread_expectedly()
549563

550564
def _reader_thread_loop(self):
551565
try:

python/paddle/fluid/tests/unittests/test_imperative_data_loader_exception.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,26 @@
1313
# limitations under the License.
1414

1515
import sys
16+
import time
1617
import unittest
1718
import numpy as np
1819
import paddle.fluid as fluid
1920
from paddle.fluid import core
2021
import paddle.compat as cpt
2122

2223

24+
def get_random_images_and_labels(image_shape, label_shape):
25+
image = np.random.random(size=image_shape).astype('float32')
26+
label = np.random.random(size=label_shape).astype('int64')
27+
return image, label
28+
29+
2330
class TestDygraphhDataLoaderWithException(unittest.TestCase):
2431
def setUp(self):
32+
self.batch_size = 8
2533
self.batch_num = 4
26-
self.capacity = 2
34+
self.epoch_num = 1
35+
self.capacity = 5
2736

2837
def test_not_capacity(self):
2938
with fluid.dygraph.guard():
@@ -77,6 +86,34 @@ def __reader__():
7786
exception = ex
7887
self.assertIsNotNone(exception)
7988

89+
def test_multi_process_with_get_timeout(self):
90+
def slow_batch_generator_creator(batch_size, batch_num):
91+
def __reader__():
92+
for _ in range(batch_num):
93+
time.sleep(80)
94+
batch_image, batch_label = get_random_images_and_labels(
95+
[batch_size, 784], [batch_size, 1])
96+
yield batch_image, batch_label
97+
98+
return __reader__
99+
100+
with fluid.dygraph.guard():
101+
loader = fluid.io.DataLoader.from_generator(
102+
capacity=self.capacity, use_multiprocess=True)
103+
loader.set_batch_generator(
104+
slow_batch_generator_creator(self.batch_size, self.batch_num),
105+
places=fluid.CPUPlace())
106+
exception = None
107+
try:
108+
for _ in range(self.epoch_num):
109+
for image, _ in loader():
110+
fluid.layers.relu(image)
111+
except core.EnforceNotMet as ex:
112+
self.assertIn("Blocking queue is killed",
113+
cpt.get_exception_message(ex))
114+
exception = ex
115+
self.assertIsNotNone(exception)
116+
80117

81118
if __name__ == '__main__':
82119
unittest.main()

0 commit comments

Comments
 (0)