Skip to content

Commit cd67cf3

Browse files
authored
Gracefully cancel async tasks (#7414)
gracefully cancel async tasks
1 parent b7fb17e commit cd67cf3

File tree

2 files changed

+64 -33 lines changed (64 additions, 33 deletions)

2 files changed

+64 -33 lines changed (64 additions, 33 deletions)

src/datasets/arrow_dataset.py

Lines changed: 20 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3433,14 +3433,19 @@ def init_buffer_and_writer():
34333433
)
34343434
return buf_writer, writer, tmp_file
34353435

3436+
tasks: List[asyncio.Task] = []
3437+
if inspect.iscoroutinefunction(function):
3438+
try:
3439+
loop = asyncio.get_running_loop()
3440+
except RuntimeError:
3441+
loop = asyncio.new_event_loop()
3442+
else:
3443+
loop = None
3444+
34363445
def iter_outputs(shard_iterable):
3446+
nonlocal tasks, loop
34373447
if inspect.iscoroutinefunction(function):
34383448
indices: Union[List[int], List[List[int]]] = []
3439-
tasks: List[asyncio.Task] = []
3440-
try:
3441-
loop = asyncio.get_running_loop()
3442-
except RuntimeError:
3443-
loop = asyncio.new_event_loop()
34443449
for i, example in shard_iterable:
34453450
indices.append(i)
34463451
tasks.append(loop.create_task(async_apply_function(example, i, offset=offset)))
@@ -3457,7 +3462,8 @@ def iter_outputs(shard_iterable):
34573462
while tasks and tasks[0].done():
34583463
yield indices.pop(0), tasks.pop(0).result()
34593464
while tasks:
3460-
yield indices.pop(0), loop.run_until_complete(tasks.pop(0))
3465+
yield indices[0], loop.run_until_complete(tasks[0])
3466+
indices.pop(0), tasks.pop(0)
34613467
else:
34623468
for i, example in shard_iterable:
34633469
yield i, apply_function(example, i, offset=offset)
@@ -3542,6 +3548,14 @@ def iter_outputs(shard_iterable):
35423548
tmp_file.close()
35433549
if os.path.exists(tmp_file.name):
35443550
os.remove(tmp_file.name)
3551+
if loop:
3552+
logger.debug(f"Canceling {len(tasks)} async tasks.")
3553+
for task in tasks:
3554+
task.cancel(msg="KeyboardInterrupt")
3555+
try:
3556+
loop.run_until_complete(asyncio.gather(*tasks))
3557+
except asyncio.CancelledError:
3558+
logger.debug("Tasks canceled.")
35453559
raise
35463560

35473561
yield rank, False, num_examples_progress_update

src/datasets/iterable_dataset.py

Lines changed: 44 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1171,15 +1171,20 @@ async def async_apply_function(key_example, indices):
11711171
processed_inputs = await self.function(*fn_args, *additional_args, **fn_kwargs)
11721172
return prepare_outputs(key_example, inputs, processed_inputs)
11731173

1174+
tasks: List[asyncio.Task] = []
1175+
if inspect.iscoroutinefunction(self.function):
1176+
try:
1177+
loop = asyncio.get_running_loop()
1178+
except RuntimeError:
1179+
loop = asyncio.new_event_loop()
1180+
else:
1181+
loop = None
1182+
11741183
def iter_outputs():
1184+
nonlocal tasks, loop
11751185
inputs_iterator = iter_batched_inputs() if self.batched else iter_inputs()
11761186
if inspect.iscoroutinefunction(self.function):
11771187
indices: Union[List[int], List[List[int]]] = []
1178-
tasks: List[asyncio.Task] = []
1179-
try:
1180-
loop = asyncio.get_running_loop()
1181-
except RuntimeError:
1182-
loop = asyncio.new_event_loop()
11831188
for i, key_example in inputs_iterator:
11841189
indices.append(i)
11851190
tasks.append(loop.create_task(async_apply_function(key_example, i)))
@@ -1196,36 +1201,48 @@ def iter_outputs():
11961201
while tasks and tasks[0].done():
11971202
yield indices.pop(0), tasks.pop(0).result()
11981203
while tasks:
1199-
yield indices.pop(0), loop.run_until_complete(tasks.pop(0))
1204+
yield indices[0], loop.run_until_complete(tasks[0])
1205+
indices.pop(0), tasks.pop(0)
12001206
else:
12011207
for i, key_example in inputs_iterator:
12021208
yield i, apply_function(key_example, i)
12031209

1204-
if self.batched:
1205-
if self._state_dict:
1206-
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
1207-
self._state_dict["num_examples_since_previous_state"] = 0
1208-
self._state_dict["previous_state_example_idx"] = current_idx
1209-
for key, transformed_batch in iter_outputs():
1210-
# yield one example at a time from the transformed batch
1211-
for example in _batch_to_examples(transformed_batch):
1212-
current_idx += 1
1213-
if self._state_dict:
1214-
self._state_dict["num_examples_since_previous_state"] += 1
1215-
if num_examples_to_skip > 0:
1216-
num_examples_to_skip -= 1
1217-
continue
1218-
yield key, example
1210+
try:
1211+
if self.batched:
12191212
if self._state_dict:
12201213
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
12211214
self._state_dict["num_examples_since_previous_state"] = 0
12221215
self._state_dict["previous_state_example_idx"] = current_idx
1223-
else:
1224-
for key, transformed_example in iter_outputs():
1225-
current_idx += 1
1226-
if self._state_dict:
1227-
self._state_dict["previous_state_example_idx"] += 1
1228-
yield key, transformed_example
1216+
for key, transformed_batch in iter_outputs():
1217+
# yield one example at a time from the transformed batch
1218+
for example in _batch_to_examples(transformed_batch):
1219+
current_idx += 1
1220+
if self._state_dict:
1221+
self._state_dict["num_examples_since_previous_state"] += 1
1222+
if num_examples_to_skip > 0:
1223+
num_examples_to_skip -= 1
1224+
continue
1225+
yield key, example
1226+
if self._state_dict:
1227+
self._state_dict["previous_state"] = self.ex_iterable.state_dict()
1228+
self._state_dict["num_examples_since_previous_state"] = 0
1229+
self._state_dict["previous_state_example_idx"] = current_idx
1230+
else:
1231+
for key, transformed_example in iter_outputs():
1232+
current_idx += 1
1233+
if self._state_dict:
1234+
self._state_dict["previous_state_example_idx"] += 1
1235+
yield key, transformed_example
1236+
except (Exception, KeyboardInterrupt):
1237+
if loop:
1238+
logger.debug(f"Canceling {len(tasks)} async tasks.")
1239+
for task in tasks:
1240+
task.cancel(msg="KeyboardInterrupt")
1241+
try:
1242+
loop.run_until_complete(asyncio.gather(*tasks))
1243+
except asyncio.CancelledError:
1244+
logger.debug("Tasks canceled.")
1245+
raise
12291246

12301247
def _iter_arrow(self, max_chunksize: Optional[int] = None) -> Iterator[Tuple[Key, pa.Table]]:
12311248
formatter: TableFormatter = get_formatter(self.formatting.format_type) if self.formatting else ArrowFormatter()

0 commit comments

Comments
 (0)