Skip to content

Commit 08addcc

Browse files
committed
feat: support ctrl+c in parallel process running
1 parent 09cb204 commit 08addcc

File tree

1 file changed

+102
-39
lines changed

1 file changed

+102
-39
lines changed

hud/datasets/execution/parallel.py

Lines changed: 102 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def _process_worker(
4040
2. Creates its own event loop
4141
3. Processes a batch of tasks asynchronously
4242
4. Returns results with their original indices
43+
5. Handles interruption signals gracefully
4344
4445
Args:
4546
task_batch: List of (index, task_dict) tuples
@@ -58,6 +59,7 @@ def _process_worker(
5859
List of (index, result) tuples
5960
"""
6061
# Import inside worker to avoid pickling issues
62+
import signal
6163
import sys
6264

6365
import hud
@@ -71,6 +73,14 @@ def _process_worker(
7173
sys.stderr.reconfigure(line_buffering=True) # type: ignore
7274
except AttributeError:
7375
pass
76+
77+
# Set up signal handler for clean interruption
78+
def signal_handler(signum, frame):
79+
logger.warning(f"Worker {worker_id}: Received interrupt signal")
80+
# Raise KeyboardInterrupt to actually interrupt the worker
81+
raise KeyboardInterrupt(f"Worker {worker_id} interrupted by user")
82+
83+
signal.signal(signal.SIGINT, signal_handler)
7484

7585
# Reinitialize telemetry in this process
7686
configure_telemetry()
@@ -157,8 +167,19 @@ async def process_single_task(index: int, task_dict: dict[str, Any]) -> tuple[in
157167
# Process all tasks in parallel within this process
158168
tasks = [process_single_task(idx, task_dict) for idx, task_dict in task_batch]
159169

160-
results = await asyncio.gather(*tasks, return_exceptions=False)
161-
return results
170+
try:
171+
results = await asyncio.gather(*tasks, return_exceptions=False)
172+
return results
173+
except asyncio.CancelledError:
174+
logger.info(f"Worker {worker_id}: Tasks cancelled due to interruption")
175+
# Return error results for all tasks
176+
return [(idx, {
177+
"error": "Task cancelled (Ctrl+C)",
178+
"isError": True,
179+
"reward": 0.0,
180+
"done": False,
181+
"content": "Task cancelled"
182+
}) for idx, _ in task_batch]
162183

163184
try:
164185
# Run the async batch processing
@@ -180,6 +201,19 @@ async def process_single_task(index: int, task_dict: dict[str, Any]) -> tuple[in
180201
logger.warning("Worker %s: Telemetry flush timed out", worker_id)
181202

182203
return results
204+
except KeyboardInterrupt:
205+
logger.info(f"Worker {worker_id}: Interrupted by user, stopping gracefully")
206+
# Return an "interrupted" error result for every task in this batch (results of tasks that had already completed are not preserved)
207+
partial_results = []
208+
for idx, _ in task_batch:
209+
partial_results.append((idx, {
210+
"error": "Worker interrupted by user (Ctrl+C)",
211+
"isError": True,
212+
"reward": 0.0,
213+
"done": False,
214+
"content": "Task interrupted"
215+
}))
216+
return partial_results
183217
except Exception as e:
184218
logger.error("[Worker %s] Batch processing failed: %s", worker_id, e)
185219
logger.error("Worker %s batch processing failed: %s", worker_id, e)
@@ -365,7 +399,8 @@ async def run_dataset_parallel_manual(
365399
)
366400

367401
# Process batches in parallel using ProcessPoolExecutor
368-
with ProcessPoolExecutor(max_workers=max_workers) as executor:
402+
executor = ProcessPoolExecutor(max_workers=max_workers)
403+
try:
369404
# Submit all batches to workers
370405
future_to_batch = {
371406
executor.submit(worker_func, batch, worker_id=i): batch
@@ -377,48 +412,76 @@ async def run_dataset_parallel_manual(
377412
total = len(task_dicts)
378413

379414
# Process results as they complete
380-
for future in as_completed(future_to_batch):
381-
batch = future_to_batch[future]
382-
383-
try:
384-
# Get results from this worker
385-
batch_results = future.result()
386-
387-
# Place results in correct positions
388-
for index, result in batch_results:
389-
results[index] = result
390-
completed += 1
391-
392-
# Calculate success rate so far
393-
successful_so_far = sum(
394-
1
395-
for r in results[:completed]
396-
if r is not None and getattr(r, "reward", 0) > 0
397-
)
398-
399-
progress_msg = (
400-
f"Progress: {completed}/{total} tasks completed "
401-
f"({100 * completed / total:.1f}%) | "
402-
f"Success rate: {successful_so_far}/{completed} "
403-
f"({100 * successful_so_far / completed:.1f}%)"
404-
)
405-
406-
logger.info(progress_msg)
415+
try:
416+
for future in as_completed(future_to_batch):
417+
batch = future_to_batch[future]
418+
419+
try:
420+
# Get results from this worker
421+
batch_results = future.result()
422+
423+
# Place results in correct positions
424+
for index, result in batch_results:
425+
results[index] = result
426+
completed += 1
427+
428+
# Calculate success rate so far
429+
successful_so_far = sum(
430+
1
431+
for r in results[:completed]
432+
if r is not None and getattr(r, "reward", 0) > 0
433+
)
407434

408-
except Exception as e:
409-
# Handle worker failure
410-
logger.error("Worker failed with exception: %s\n%s", e, traceback.format_exc())
435+
progress_msg = (
436+
f"Progress: {completed}/{total} tasks completed "
437+
f"({100 * completed / total:.1f}%) | "
438+
f"Success rate: {successful_so_far}/{completed} "
439+
f"({100 * successful_so_far / completed:.1f}%)"
440+
)
411441

412-
# Mark all tasks in this batch as failed
413-
for index, _ in batch:
414-
results[index] = {
415-
"error": f"Worker process failed: {e}",
442+
logger.info(progress_msg)
443+
444+
except Exception as e:
445+
# Handle worker failure
446+
logger.error("Worker failed with exception: %s\n%s", e, traceback.format_exc())
447+
448+
# Mark all tasks in this batch as failed
449+
for index, _ in batch:
450+
results[index] = {
451+
"error": f"Worker process failed: {e}",
452+
"isError": True,
453+
"reward": 0.0,
454+
"done": False,
455+
"content": f"Worker process failed: {e}",
456+
}
457+
completed += 1
458+
459+
except KeyboardInterrupt:
460+
logger.warning("\n⚠️ Parallel evaluation interrupted by user (Ctrl+C)")
461+
logger.info("Cancelling pending tasks...")
462+
463+
# Cancel all pending futures
464+
for future in future_to_batch:
465+
if not future.done():
466+
future.cancel()
467+
468+
# Mark uncompleted tasks as interrupted
469+
for i, r in enumerate(results):
470+
if r is None:
471+
results[i] = {
472+
"error": "Evaluation interrupted by user",
416473
"isError": True,
417474
"reward": 0.0,
418475
"done": False,
419-
"content": f"Worker process failed: {e}",
476+
"content": "Task interrupted (Ctrl+C)",
420477
}
421-
completed += 1
478+
479+
logger.info(f"Interrupted after {completed}/{total} tasks")
480+
raise # Re-raise to propagate the interrupt
481+
482+
finally:
483+
# Always shut down the executor: do not wait for running workers, and cancel futures that have not started
484+
executor.shutdown(wait=False, cancel_futures=True)
422485

423486
# Verify all results are populated
424487
missing = [i for i, r in enumerate(results) if r is None]

0 commit comments

Comments
 (0)