@@ -40,6 +40,7 @@ def _process_worker(
4040 2. Creates its own event loop
4141 3. Processes a batch of tasks asynchronously
4242 4. Returns results with their original indices
43+ 5. Handles interruption signals gracefully
4344
4445 Args:
4546 task_batch: List of (index, task_dict) tuples
@@ -58,6 +59,7 @@ def _process_worker(
5859 List of (index, result) tuples
5960 """
6061 # Import inside worker to avoid pickling issues
62+ import signal
6163 import sys
6264
6365 import hud
@@ -71,6 +73,14 @@ def _process_worker(
7173 sys .stderr .reconfigure (line_buffering = True ) # type: ignore
7274 except AttributeError :
7375 pass
76+
77+ # Set up signal handler for clean interruption
78+ def signal_handler (signum , frame ):
79+ logger .warning (f"Worker { worker_id } : Received interrupt signal" )
80+ # Raise KeyboardInterrupt to actually interrupt the worker
81+ raise KeyboardInterrupt (f"Worker { worker_id } interrupted by user" )
82+
83+ signal .signal (signal .SIGINT , signal_handler )
7484
7585 # Reinitialize telemetry in this process
7686 configure_telemetry ()
@@ -157,8 +167,19 @@ async def process_single_task(index: int, task_dict: dict[str, Any]) -> tuple[in
157167 # Process all tasks in parallel within this process
158168 tasks = [process_single_task (idx , task_dict ) for idx , task_dict in task_batch ]
159169
160- results = await asyncio .gather (* tasks , return_exceptions = False )
161- return results
170+ try :
171+ results = await asyncio .gather (* tasks , return_exceptions = False )
172+ return results
173+ except asyncio .CancelledError :
174+ logger .info (f"Worker { worker_id } : Tasks cancelled due to interruption" )
175+ # Return error results for all tasks
176+ return [(idx , {
177+ "error" : "Task cancelled (Ctrl+C)" ,
178+ "isError" : True ,
179+ "reward" : 0.0 ,
180+ "done" : False ,
181+ "content" : "Task cancelled"
182+ }) for idx , _ in task_batch ]
162183
163184 try :
164185 # Run the async batch processing
@@ -180,6 +201,19 @@ async def process_single_task(index: int, task_dict: dict[str, Any]) -> tuple[in
180201 logger .warning ("Worker %s: Telemetry flush timed out" , worker_id )
181202
182203 return results
204+ except KeyboardInterrupt :
205+ logger .info (f"Worker { worker_id } : Interrupted by user, stopping gracefully" )
206+ # Mark every task in this batch as interrupted (per-task completed results are not tracked here)
207+ partial_results = []
208+ for idx , _ in task_batch :
209+ partial_results .append ((idx , {
210+ "error" : "Worker interrupted by user (Ctrl+C)" ,
211+ "isError" : True ,
212+ "reward" : 0.0 ,
213+ "done" : False ,
214+ "content" : "Task interrupted"
215+ }))
216+ return partial_results
183217 except Exception as e :
184218 logger .error ("[Worker %s] Batch processing failed: %s" , worker_id , e )
185219 logger .error ("Worker %s batch processing failed: %s" , worker_id , e )
@@ -365,7 +399,8 @@ async def run_dataset_parallel_manual(
365399 )
366400
367401 # Process batches in parallel using ProcessPoolExecutor
368- with ProcessPoolExecutor (max_workers = max_workers ) as executor :
402+ executor = ProcessPoolExecutor (max_workers = max_workers )
403+ try :
369404 # Submit all batches to workers
370405 future_to_batch = {
371406 executor .submit (worker_func , batch , worker_id = i ): batch
@@ -377,48 +412,76 @@ async def run_dataset_parallel_manual(
377412 total = len (task_dicts )
378413
379414 # Process results as they complete
380- for future in as_completed (future_to_batch ):
381- batch = future_to_batch [future ]
382-
383- try :
384- # Get results from this worker
385- batch_results = future .result ()
386-
387- # Place results in correct positions
388- for index , result in batch_results :
389- results [index ] = result
390- completed += 1
391-
392- # Calculate success rate so far
393- successful_so_far = sum (
394- 1
395- for r in results [:completed ]
396- if r is not None and getattr (r , "reward" , 0 ) > 0
397- )
398-
399- progress_msg = (
400- f"Progress: { completed } /{ total } tasks completed "
401- f"({ 100 * completed / total :.1f} %) | "
402- f"Success rate: { successful_so_far } /{ completed } "
403- f"({ 100 * successful_so_far / completed :.1f} %)"
404- )
405-
406- logger .info (progress_msg )
415+ try :
416+ for future in as_completed (future_to_batch ):
417+ batch = future_to_batch [future ]
418+
419+ try :
420+ # Get results from this worker
421+ batch_results = future .result ()
422+
423+ # Place results in correct positions
424+ for index , result in batch_results :
425+ results [index ] = result
426+ completed += 1
427+
428+ # Calculate success rate so far
429+ successful_so_far = sum (
430+ 1
431+ for r in results [:completed ]
432+ if r is not None and getattr (r , "reward" , 0 ) > 0
433+ )
407434
408- except Exception as e :
409- # Handle worker failure
410- logger .error ("Worker failed with exception: %s\n %s" , e , traceback .format_exc ())
435+ progress_msg = (
436+ f"Progress: { completed } /{ total } tasks completed "
437+ f"({ 100 * completed / total :.1f} %) | "
438+ f"Success rate: { successful_so_far } /{ completed } "
439+ f"({ 100 * successful_so_far / completed :.1f} %)"
440+ )
411441
412- # Mark all tasks in this batch as failed
413- for index , _ in batch :
414- results [index ] = {
415- "error" : f"Worker process failed: { e } " ,
442+ logger .info (progress_msg )
443+
444+ except Exception as e :
445+ # Handle worker failure
446+ logger .error ("Worker failed with exception: %s\n %s" , e , traceback .format_exc ())
447+
448+ # Mark all tasks in this batch as failed
449+ for index , _ in batch :
450+ results [index ] = {
451+ "error" : f"Worker process failed: { e } " ,
452+ "isError" : True ,
453+ "reward" : 0.0 ,
454+ "done" : False ,
455+ "content" : f"Worker process failed: { e } " ,
456+ }
457+ completed += 1
458+
459+ except KeyboardInterrupt :
460+ logger .warning ("\n ⚠️ Parallel evaluation interrupted by user (Ctrl+C)" )
461+ logger .info ("Cancelling pending tasks..." )
462+
463+ # Cancel all pending futures
464+ for future in future_to_batch :
465+ if not future .done ():
466+ future .cancel ()
467+
468+ # Mark uncompleted tasks as interrupted
469+ for i , r in enumerate (results ):
470+ if r is None :
471+ results [i ] = {
472+ "error" : "Evaluation interrupted by user" ,
416473 "isError" : True ,
417474 "reward" : 0.0 ,
418475 "done" : False ,
419- "content" : f"Worker process failed: { e } " ,
476+ "content" : "Task interrupted (Ctrl+C) " ,
420477 }
421- completed += 1
478+
479+ logger .info (f"Interrupted after { completed } /{ total } tasks" )
480+ raise # Re-raise to propagate the interrupt
481+
482+ finally :
483+ # Always shutdown the executor properly
484+ executor .shutdown (wait = False , cancel_futures = True )
422485
423486 # Verify all results are populated
424487 missing = [i for i , r in enumerate (results ) if r is None ]