@@ -1171,15 +1171,20 @@ async def async_apply_function(key_example, indices):
11711171 processed_inputs = await self .function (* fn_args , * additional_args , ** fn_kwargs )
11721172 return prepare_outputs (key_example , inputs , processed_inputs )
11731173
1174+ tasks : List [asyncio .Task ] = []
1175+ if inspect .iscoroutinefunction (self .function ):
1176+ try :
1177+ loop = asyncio .get_running_loop ()
1178+ except RuntimeError :
1179+ loop = asyncio .new_event_loop ()
1180+ else :
1181+ loop = None
1182+
11741183 def iter_outputs ():
1184+ nonlocal tasks , loop
11751185 inputs_iterator = iter_batched_inputs () if self .batched else iter_inputs ()
11761186 if inspect .iscoroutinefunction (self .function ):
11771187 indices : Union [List [int ], List [List [int ]]] = []
1178- tasks : List [asyncio .Task ] = []
1179- try :
1180- loop = asyncio .get_running_loop ()
1181- except RuntimeError :
1182- loop = asyncio .new_event_loop ()
11831188 for i , key_example in inputs_iterator :
11841189 indices .append (i )
11851190 tasks .append (loop .create_task (async_apply_function (key_example , i )))
@@ -1196,36 +1201,48 @@ def iter_outputs():
11961201 while tasks and tasks [0 ].done ():
11971202 yield indices .pop (0 ), tasks .pop (0 ).result ()
11981203 while tasks :
1199- yield indices .pop (0 ), loop .run_until_complete (tasks .pop (0 ))
1204+ yield indices [0 ], loop .run_until_complete (tasks [0 ])
1205+ indices .pop (0 ), tasks .pop (0 )
12001206 else :
12011207 for i , key_example in inputs_iterator :
12021208 yield i , apply_function (key_example , i )
12031209
1204- if self .batched :
1205- if self ._state_dict :
1206- self ._state_dict ["previous_state" ] = self .ex_iterable .state_dict ()
1207- self ._state_dict ["num_examples_since_previous_state" ] = 0
1208- self ._state_dict ["previous_state_example_idx" ] = current_idx
1209- for key , transformed_batch in iter_outputs ():
1210- # yield one example at a time from the transformed batch
1211- for example in _batch_to_examples (transformed_batch ):
1212- current_idx += 1
1213- if self ._state_dict :
1214- self ._state_dict ["num_examples_since_previous_state" ] += 1
1215- if num_examples_to_skip > 0 :
1216- num_examples_to_skip -= 1
1217- continue
1218- yield key , example
1210+ try :
1211+ if self .batched :
12191212 if self ._state_dict :
12201213 self ._state_dict ["previous_state" ] = self .ex_iterable .state_dict ()
12211214 self ._state_dict ["num_examples_since_previous_state" ] = 0
12221215 self ._state_dict ["previous_state_example_idx" ] = current_idx
1223- else :
1224- for key , transformed_example in iter_outputs ():
1225- current_idx += 1
1226- if self ._state_dict :
1227- self ._state_dict ["previous_state_example_idx" ] += 1
1228- yield key , transformed_example
1216+ for key , transformed_batch in iter_outputs ():
1217+ # yield one example at a time from the transformed batch
1218+ for example in _batch_to_examples (transformed_batch ):
1219+ current_idx += 1
1220+ if self ._state_dict :
1221+ self ._state_dict ["num_examples_since_previous_state" ] += 1
1222+ if num_examples_to_skip > 0 :
1223+ num_examples_to_skip -= 1
1224+ continue
1225+ yield key , example
1226+ if self ._state_dict :
1227+ self ._state_dict ["previous_state" ] = self .ex_iterable .state_dict ()
1228+ self ._state_dict ["num_examples_since_previous_state" ] = 0
1229+ self ._state_dict ["previous_state_example_idx" ] = current_idx
1230+ else :
1231+ for key , transformed_example in iter_outputs ():
1232+ current_idx += 1
1233+ if self ._state_dict :
1234+ self ._state_dict ["previous_state_example_idx" ] += 1
1235+ yield key , transformed_example
1236+ except (Exception , KeyboardInterrupt ):
1237+ if loop :
1238+ logger .debug (f"Canceling { len (tasks )} async tasks." )
1239+ for task in tasks :
1240+ task .cancel (msg = "KeyboardInterrupt" )
1241+ try :
1242+ loop .run_until_complete (asyncio .gather (* tasks ))
1243+ except asyncio .CancelledError :
1244+ logger .debug ("Tasks canceled." )
1245+ raise
12291246
12301247 def _iter_arrow (self , max_chunksize : Optional [int ] = None ) -> Iterator [Tuple [Key , pa .Table ]]:
12311248 formatter : TableFormatter = get_formatter (self .formatting .format_type ) if self .formatting else ArrowFormatter ()
0 commit comments