11import time , typing , threading , dataclasses
22import rich , rich .progress
3+ import traceback
34
45from .printer import cons
56
6-
7-
8-
class WorkerThread(threading.Thread):
    """Thread that records the outcome of its target instead of losing it.

    An exception raised by the target is caught inside ``run`` and stored on
    the instance, so the code that later joins the thread can inspect it and
    decide whether to re-raise.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the outcome fields, then the underlying ``Thread``.

        All positional and keyword arguments are forwarded unchanged to
        ``threading.Thread.__init__``.
        """
        # Exception instance raised by the target, or None.
        self.exc = None
        # Formatted traceback text (a str, not a sys.exc_info() tuple), or None.
        self.exc_info = None
        # True only once the target has returned without raising.
        self.completed_successfully = False

        super().__init__(*args, **kwargs)

    def run(self):
        """Run the target, capturing any ``Exception`` it raises.

        Only ``Exception`` subclasses are captured; other ``BaseException``
        types (e.g. ``SystemExit``) propagate as they would in a plain thread.
        """
        try:
            # Compare against None (as the stdlib does) so that a callable
            # object that happens to be falsy is still invoked.
            if self._target is not None:
                self._target(*self._args, **self._kwargs)
                self.completed_successfully = True
        except Exception as exc:
            self.exc = exc
            # Format now, while the traceback is still the active exception.
            self.exc_info = traceback.format_exc()
        finally:
            # Mirror threading.Thread.run: drop the references so a finished
            # thread does not keep the target and its arguments alive.
            del self._target, self._args, self._kwargs
2124
2225
2326@dataclasses .dataclass
@@ -35,7 +38,6 @@ class Task:
3538 args : typing .List [typing .Any ]
3639 load : float
3740
38-
3941def sched (tasks : typing .List [Task ], nThreads : int , devices : typing .Set [int ] = None ) -> None :
4042 nAvailable : int = nThreads
4143 threads : typing .List [WorkerThreadHolder ] = []
@@ -46,9 +48,39 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
4648 nonlocal threads , nAvailable
4749
4850 for threadID , threadHolder in enumerate (threads ):
49- if not threadHolder .thread .is_alive ():
51+ # Check whether the thread has already finished (is no longer alive)
52+ thread_not_alive = not threadHolder .thread .is_alive ()
53+
54+ if thread_not_alive :
55+ # Properly join the thread with timeout to prevent infinite hangs
56+ try :
57+ threadHolder .thread .join (timeout = 30.0 ) # 30 second timeout
58+
59+ # Double-check that thread actually finished joining
60+ if threadHolder .thread .is_alive ():
61+ # Thread didn't finish within timeout - this is a serious issue
62+ raise RuntimeError (f"Thread { threadID } failed to join within 30 seconds timeout. "
63+ f"Thread may be hung or in an inconsistent state." )
64+
65+ except Exception as join_exc :
66+ # Handle join-specific exceptions with more context
67+ raise RuntimeError (f"Failed to join thread { threadID } : { join_exc } . "
68+ f"This may indicate a system threading issue or hung test case." ) from join_exc
69+
70+ # Check for and propagate any exceptions that occurred in the worker thread
71+ # But only if the worker function didn't complete successfully
72+ # (This allows test failures to be handled gracefully by handle_case)
5073 if threadHolder .thread .exc is not None :
51- raise threadHolder .thread .exc
74+ if threadHolder .thread .completed_successfully :
75+ # Test framework handled the exception gracefully (e.g., test failure)
76+ # Don't re-raise - this is expected behavior
77+ pass
78+ # Unhandled exception - this indicates a real problem
79+ elif hasattr (threadHolder .thread , 'exc_info' ) and threadHolder .thread .exc_info :
80+ error_msg = f"Worker thread { threadID } failed with unhandled exception:\n { threadHolder .thread .exc_info } "
81+ raise RuntimeError (error_msg ) from threadHolder .thread .exc
82+ else :
83+ raise threadHolder .thread .exc
5284
5385 nAvailable += threadHolder .ppn
5486 for device in threadHolder .devices or set ():
@@ -60,7 +92,6 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
6092
6193 break
6294
63-
6495 with rich .progress .Progress (console = cons .raw , transient = True ) as progress :
6596 queue_tracker = progress .add_task ("Queued " , total = len (tasks ))
6697 complete_tracker = progress .add_task ("Completed" , total = len (tasks ))
@@ -99,8 +130,7 @@ def join_first_dead_thread(progress, complete_tracker) -> None:
99130
100131 threads .append (WorkerThreadHolder (thread , task .ppn , task .load , use_devices ))
101132
102-
103- # Wait for the lasts tests to complete
133+ # Wait for the last tests to complete (MOVED INSIDE CONTEXT)
104134 while len (threads ) != 0 :
105135 # Keep track of threads that are done
106136 join_first_dead_thread (progress , complete_tracker )
0 commit comments