Skip to content

Commit cfcdf6e

Browse files
NivekT authored and facebook-github-bot committed
Adding 'pause' and 'resume' operations to halt DataPipes (#879)
Summary: Pull Request resolved: #879 The goal of this PR is fully stop DataPipe activities in preparation of snapshotting (which requires a halted state), so buffers will not be changing while we record the snapshot. Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D41744759 Pulled By: NivekT fbshipit-source-id: 7045a23ea2fa7499184b50c09f706f46472e88ec
1 parent e908330 commit cfcdf6e

File tree

9 files changed

+629
-58
lines changed

9 files changed

+629
-58
lines changed

test/dataloader2/test_dataloader2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,7 @@ def _collect_data(self, datapipe, reading_service_gen):
252252
result.append(row)
253253
for row in dl:
254254
result.append(row)
255+
dl.shutdown()
255256
return result
256257

257258
@staticmethod
Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import multiprocessing as mp
9+
import unittest
10+
from unittest import TestCase
11+
12+
from torch.testing._internal.common_utils import instantiate_parametrized_tests, parametrize
13+
from torchdata.dataloader2 import DataLoader2, DataLoader2Iterator, PrototypeMultiProcessingReadingService
14+
from torchdata.datapipes.iter import IterableWrapper
15+
16+
17+
def _add_one(x: int) -> int:
18+
return x + 1
19+
20+
21+
# Test DataPipes
# Module-level fixtures shared by every test in this file.
n_elements = 10  # number of items each test pipeline yields
# Baseline pipeline: shuffled range with sharding so it can be split across workers.
dp1 = IterableWrapper(range(n_elements)).shuffle().sharding_filter()
# Same pipeline with two prefetch stages, so a `pause` must halt two buffers.
double_pause_dp = dp1.prefetch().prefetch()
test_dps = [dp1, double_pause_dp]


# Reusable `parametrize` decorators: every available multiprocessing start
# method, and each of the test pipelines above.
mp_ctx_parametrize = parametrize("ctx", mp.get_all_start_methods())
dp_parametrize = parametrize("dp", test_dps)
32+
class TestPrototypeMultiProcessingReadingService(TestCase):
    r"""
    This tests specific functionalities of PrototypeMultiProcessingReadingService, notably
    `pause`, `resume`, `snapshot`.
    """

    @mp_ctx_parametrize
    def test_reading_service_pause_resume_0_worker(self, ctx) -> None:

        # Functional Test: Verifies that this ReadingService will raise error when `pause/resume` is used
        # with `num_workers = 0`
        rs0 = PrototypeMultiProcessingReadingService(
            num_workers=0, worker_prefetch_cnt=0, main_prefetch_cnt=0, multiprocessing_context=ctx
        )
        dl0: DataLoader2 = DataLoader2(dp1, reading_service=rs0)
        res0 = []
        for i, x in enumerate(dl0):
            res0.append(x)
            if i in {2}:
                # Without worker processes there is nothing to pause/resume,
                # so both operations are expected to raise.
                with self.assertRaisesRegex(RuntimeError, r"pause"):
                    dl0._pause()
                with self.assertRaisesRegex(RuntimeError, r"resume"):
                    dl0._resume()
        dl0.shutdown()

    @mp_ctx_parametrize
    @dp_parametrize
    @parametrize(
        "n_workers,worker_prefetch_cnt,main_prefetch_cnt",
        [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 0), (2, 0, 2), (2, 2, 2)],
    )
    def test_reading_service_pause_resume(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:

        # Functional Test: Testing various configuration of DataPipe/ReadingService to ensure the pipeline
        # properly pauses and resumes
        rs = PrototypeMultiProcessingReadingService(
            num_workers=n_workers,
            worker_prefetch_cnt=worker_prefetch_cnt,
            main_prefetch_cnt=main_prefetch_cnt,
            multiprocessing_context=ctx,
        )
        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
        res = []
        for i, x in enumerate(dl):
            res.append(x)
            # Pause/resume mid-epoch and near the end; iteration must continue
            # afterwards without dropping or duplicating elements.
            if i in {2, n_elements - 2}:
                dl._pause()
                dl._resume()

        self.assertEqual(
            list(range(n_elements)),
            sorted(res),
            msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, "
            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, "
            f"main_prefetch_cnt = {rs.main_prefetch_cnt}",
        )
        dl.shutdown()

    @mp_ctx_parametrize
    @dp_parametrize
    @parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(2, 0, 1), (2, 1, 0), (2, 0, 0)])
    def test_reading_service_pause_stop_yield(self, ctx, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:

        # Functional Test: Confirms that `dl` will stop yielding elements after `_pause` is called
        rs = PrototypeMultiProcessingReadingService(
            num_workers=n_workers,
            worker_prefetch_cnt=worker_prefetch_cnt,
            main_prefetch_cnt=main_prefetch_cnt,
            multiprocessing_context=ctx,
        )
        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
        res = []
        for i, x in enumerate(dl):
            res.append(x)
            if i in {2}:
                dl._pause()
        # Exactly the 3 elements collected before `_pause` (indices 0..2) are yielded.
        self.assertEqual(
            3,
            len(res),
            msg=f"The test is failing with '{ctx}', num_workers = {rs.num_workers}, "
            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
        )
        dl.shutdown()

    @dp_parametrize
    @parametrize("n_workers,worker_prefetch_cnt,main_prefetch_cnt", [(1, 0, 0), (1, 0, 2), (2, 0, 0), (2, 2, 2)])
    def test_reading_service_limit(self, dp, n_workers, worker_prefetch_cnt, main_prefetch_cnt) -> None:

        rs = PrototypeMultiProcessingReadingService(
            num_workers=n_workers, worker_prefetch_cnt=worker_prefetch_cnt, main_prefetch_cnt=main_prefetch_cnt
        )

        dl: DataLoader2 = DataLoader2(dp, reading_service=rs)
        res = []
        cumulative_res = []
        n_limit = 3

        it: DataLoader2Iterator = iter(dl)
        it.limit(n_limit)
        for x in it:
            res.append(x)
        # Functional Test: Verify that the number of elements yielded equals to the specified limit
        self.assertEqual(
            n_limit,
            len(res),  # 3
            msg=f"The test is failing with default multiprocessing method, "
            f"num_workers = {rs.num_workers}, "
            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
        )
        cumulative_res.extend(res)

        # Functional Test: Calling `next` after `limit` will trigger `StopIteration`
        with self.assertRaises(StopIteration):
            next(it)

        # Functional Test: Verify that `limit` persists without the need to set it again
        it.resume()
        res = []
        for x in it:
            res.append(x)
        self.assertEqual(
            n_limit,
            len(res),  # 3
            msg=f"The test is failing with default multiprocessing method, "
            f"num_workers = {rs.num_workers}, "
            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
        )
        cumulative_res.extend(res)

        # Functional Test: Clear the `limit` and yield the rest of the elements
        it.limit(None)
        it.resume()
        res = []
        for x in it:
            res.append(x)
        self.assertEqual(
            n_elements - 2 * n_limit,
            len(res),  # 4
            msg=f"The test is failing with default multiprocessing method, "
            f"num_workers = {rs.num_workers}, "
            f"worker_prefetch_cnt = {rs.worker_prefetch_cnt}, main_prefetch_cnt = {rs.main_prefetch_cnt}",
        )

        cumulative_res.extend(res)
        # Across the three limited mini-epochs, every element is seen exactly once.
        self.assertEqual(list(range(n_elements)), sorted(cumulative_res))

        # Functional Test: Setting `limit` to a different value during after each mini-epoch
        dl2: DataLoader2 = DataLoader2(double_pause_dp, reading_service=rs)
        res = []
        it2: DataLoader2Iterator = iter(dl2)
        it2.limit(3)
        for x in it2:
            res.append(x)

        # Limit can be set before `resume`
        it2.limit(4)
        it2.resume()
        for x in it2:
            res.append(x)
        self.assertEqual(7, len(res))

        # Limit can also be set after `resume`, but before the next `for` loop
        it2.resume()
        it2.limit(2)
        for x in it2:
            res.append(x)
        self.assertEqual(9, len(res))

        # Fix: shut both loaders down so worker processes are not leaked between
        # parametrized runs — consistent with every other test in this class.
        dl.shutdown()
        dl2.shutdown()

    # TODO: Add test cases once `pause`/`resume` are officially supported with round-robin
    #       sharding (`sharding_round_robin_dispatch`); it currently raises a warning, and
    #       pausing the dispatching process does not work reliably with `num_workers > 1`
    #       (possible race condition between pause messages and the dispatching queue).

    # TODO: Implemented in an upcoming PR
    # def test_reading_service_snapshot(self) -> None:
    #     pass
    #
    # def test_dataloader2_snapshot(self) -> None:
    #     pass
286+
287+
# Expand the `parametrize` decorators above into concrete, individually-named
# test methods on the class so unittest discovery can run each combination.
instantiate_parametrized_tests(TestPrototypeMultiProcessingReadingService)
289+
290+
# Allow running this test module directly (e.g. `python test_file.py`).
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)