From 10dd27a6af79b015d37491b3feb7a144d91c1722 Mon Sep 17 00:00:00 2001 From: Spencer Bryngelson Date: Tue, 8 Jul 2025 09:11:38 -0400 Subject: [PATCH 1/2] fix slurm hangs when a job fails --- toolchain/mfc/sched.py | 49 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/toolchain/mfc/sched.py b/toolchain/mfc/sched.py index 39ae78ea57..115568468c 100644 --- a/toolchain/mfc/sched.py +++ b/toolchain/mfc/sched.py @@ -1,14 +1,14 @@ import time, typing, threading, dataclasses import rich, rich.progress +import traceback from .printer import cons - - - class WorkerThread(threading.Thread): def __init__(self, *args, **kwargs): self.exc = None + self.exc_info = None # Store full exception information for better debugging + self.completed_successfully = False # Track if the target function completed threading.Thread.__init__(self, *args, **kwargs) @@ -16,8 +16,11 @@ def run(self): try: if self._target: self._target(*self._args, **self._kwargs) + self.completed_successfully = True # Mark as completed successfully except Exception as exc: self.exc = exc + # Store the full traceback for better error reporting + self.exc_info = traceback.format_exc() @dataclasses.dataclass @@ -35,7 +38,6 @@ class Task: args: typing.List[typing.Any] load: float - def sched(tasks: typing.List[Task], nThreads: int, devices: typing.Set[int] = None) -> None: nAvailable: int = nThreads threads: typing.List[WorkerThreadHolder] = [] @@ -46,9 +48,40 @@ def join_first_dead_thread(progress, complete_tracker) -> None: nonlocal threads, nAvailable for threadID, threadHolder in enumerate(threads): - if not threadHolder.thread.is_alive(): + # Check if thread is not alive OR if it's been running for too long + thread_not_alive = not threadHolder.thread.is_alive() + + if thread_not_alive: + # Properly join the thread with timeout to prevent infinite hangs + try: + threadHolder.thread.join(timeout=30.0) # 30 second timeout + + # Double-check that thread actually finished joining + if threadHolder.thread.is_alive(): + # Thread didn't finish within timeout - this is a serious issue + raise Exception(f"Thread {threadID} failed to join within 30 seconds timeout. " + f"Thread may be hung or in an inconsistent state.") + + except Exception as join_exc: + # Handle join-specific exceptions with more context + raise Exception(f"Failed to join thread {threadID}: {join_exc}. " + f"This may indicate a system threading issue or hung test case.") + + # Check for and propagate any exceptions that occurred in the worker thread + # But only if the worker function didn't complete successfully + # (This allows test failures to be handled gracefully by handle_case) if threadHolder.thread.exc is not None: - raise threadHolder.thread.exc + if threadHolder.thread.completed_successfully: + # Test framework handled the exception gracefully (e.g., test failure) + # Don't re-raise - this is expected behavior + pass + else: + # Unhandled exception - this indicates a real problem + if hasattr(threadHolder.thread, 'exc_info') and threadHolder.thread.exc_info: + error_msg = f"Worker thread {threadID} failed with unhandled exception:\n{threadHolder.thread.exc_info}" + raise Exception(error_msg) from threadHolder.thread.exc + else: + raise threadHolder.thread.exc nAvailable += threadHolder.ppn for device in threadHolder.devices or set(): @@ -60,7 +93,6 @@ def join_first_dead_thread(progress, complete_tracker) -> None: break - with rich.progress.Progress(console=cons.raw, transient=True) as progress: queue_tracker = progress.add_task("Queued ", total=len(tasks)) complete_tracker = progress.add_task("Completed", total=len(tasks)) @@ -99,8 +131,7 @@ def join_first_dead_thread(progress, complete_tracker) -> None: threads.append(WorkerThreadHolder(thread, task.ppn, task.load, use_devices)) - - # Wait for the lasts tests to complete + # Wait for the last tests to complete (MOVED INSIDE CONTEXT) while len(threads) != 0: # Keep track of threads that are done join_first_dead_thread(progress, complete_tracker) From 7dfabd890add2553dffccfe7e6c8b5e670646059 Mon Sep 17 00:00:00 2001 From: Spencer Bryngelson Date: Tue, 8 Jul 2025 09:51:44 -0400 Subject: [PATCH 2/2] fix lint --- toolchain/mfc/sched.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/toolchain/mfc/sched.py b/toolchain/mfc/sched.py index 115568468c..45ec5affb2 100644 --- a/toolchain/mfc/sched.py +++ b/toolchain/mfc/sched.py @@ -50,23 +50,23 @@ def join_first_dead_thread(progress, complete_tracker) -> None: for threadID, threadHolder in enumerate(threads): # Check if thread is not alive OR if it's been running for too long thread_not_alive = not threadHolder.thread.is_alive() - + if thread_not_alive: # Properly join the thread with timeout to prevent infinite hangs try: threadHolder.thread.join(timeout=30.0) # 30 second timeout - + # Double-check that thread actually finished joining if threadHolder.thread.is_alive(): # Thread didn't finish within timeout - this is a serious issue - raise Exception(f"Thread {threadID} failed to join within 30 seconds timeout. " - f"Thread may be hung or in an inconsistent state.") - + raise RuntimeError(f"Thread {threadID} failed to join within 30 seconds timeout. " + f"Thread may be hung or in an inconsistent state.") + except Exception as join_exc: # Handle join-specific exceptions with more context - raise Exception(f"Failed to join thread {threadID}: {join_exc}. " - f"This may indicate a system threading issue or hung test case.") - + raise RuntimeError(f"Failed to join thread {threadID}: {join_exc}. " + f"This may indicate a system threading issue or hung test case.") from join_exc + # Check for and propagate any exceptions that occurred in the worker thread # But only if the worker function didn't complete successfully # (This allows test failures to be handled gracefully by handle_case) @@ -75,13 +75,12 @@ def join_first_dead_thread(progress, complete_tracker) -> None: # Test framework handled the exception gracefully (e.g., test failure) # Don't re-raise - this is expected behavior pass + # Unhandled exception - this indicates a real problem + elif hasattr(threadHolder.thread, 'exc_info') and threadHolder.thread.exc_info: + error_msg = f"Worker thread {threadID} failed with unhandled exception:\n{threadHolder.thread.exc_info}" + raise RuntimeError(error_msg) from threadHolder.thread.exc else: - # Unhandled exception - this indicates a real problem - if hasattr(threadHolder.thread, 'exc_info') and threadHolder.thread.exc_info: - error_msg = f"Worker thread {threadID} failed with unhandled exception:\n{threadHolder.thread.exc_info}" - raise Exception(error_msg) from threadHolder.thread.exc - else: - raise threadHolder.thread.exc + raise threadHolder.thread.exc nAvailable += threadHolder.ppn for device in threadHolder.devices or set():