@@ -313,6 +313,7 @@ def _worker_loop(
313313
314314 traceback .print_exc ()
315315 logger .error (f"Error in worker loop (rank { rank } ): { e } " )
316+ queue_out .put (e ) # any exception caught in the worker will be raised to the main process
316317 finally :
317318 del module
318319 torch .cuda .synchronize ()
@@ -365,29 +366,44 @@ def load_loras(self, lora_args: List[Dict[str, any]], fused: bool = True):
365366 }
366367 )
367368 try :
368- _ = self .queue_out .get (timeout = PARALLEL_LORA_TIMEOUT_SEC )
369+ res = self .queue_out .get (timeout = PARALLEL_LORA_TIMEOUT_SEC )
370+ if isinstance (res , Exception ):
371+ raise res
369372 except Empty :
370- logger .error ("Parallel model load LoRA timeout" )
371- raise RuntimeError ("Parallel model load LoRA timeout" )
372- logger .info ("Parallel model load LoRA done" )
373+ logger .error ("ParallelModel load LoRA timeout" )
374+ raise RuntimeError ("ParallelModel load LoRA timeout" )
375+ except Exception as e :
376+ logger .error (f"ParallelModel load LoRA error: { e } " )
377+ raise RuntimeError (f"ParallelModel load LoRA error: { e } " )
378+ logger .info ("ParallelModel load LoRA done" )
373379
def unload_loras(self):
    """Post an ``unload_loras`` request and wait for a single reply.

    The request ``{"method": "unload_loras"}`` is put on ``self.queue_in``
    and one item is then read from ``self.queue_out``. The reply value itself
    is not used; only whether it is an exception object matters.

    Raises:
        RuntimeError: if no reply arrives within
            ``PARALLEL_LORA_TIMEOUT_SEC`` seconds, if reading the reply
            fails, or if the reply is an ``Exception`` instance. In the last
            two cases the underlying exception is attached as ``__cause__``.
    """
    self.queue_in.put({"method": "unload_loras"})
    # Keep the try body limited to the blocking read so the handlers below
    # only deal with failures of the queue itself.
    try:
        res = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
    except Empty:
        logger.error("ParallelModel unload LoRA timeout")
        raise RuntimeError("ParallelModel unload LoRA timeout")
    except Exception as e:
        logger.error(f"ParallelModel unload LoRA error: {e}")
        raise RuntimeError(f"ParallelModel unload LoRA error: {e}") from e
    # An exception object received as the reply is surfaced to the caller,
    # chained so the original error type stays visible in the traceback.
    if isinstance(res, Exception):
        logger.error(f"ParallelModel unload LoRA error: {res}")
        raise RuntimeError(f"ParallelModel unload LoRA error: {res}") from res
    logger.info("ParallelModel unload LoRA done")
382393
def forward(self, **kwargs):
    """Send ``kwargs`` to the input queue and return the reply from the output queue.

    The keyword arguments are put on ``self.queue_in`` as a single dict and
    one item is then read from ``self.queue_out``.

    Returns:
        The item read from ``self.queue_out``, unchanged, when it is not an
        ``Exception`` instance.

    Raises:
        RuntimeError: if no reply arrives within
            ``PARALLEL_FWD_TIMEOUT_SEC`` seconds, if reading the reply fails,
            or if the reply is an ``Exception`` instance. In the last two
            cases the underlying exception is attached as ``__cause__``.
    """
    self.queue_in.put(kwargs)
    # Keep the try body limited to the blocking read so the handlers below
    # only deal with failures of the queue itself.
    try:
        res = self.queue_out.get(timeout=PARALLEL_FWD_TIMEOUT_SEC)
    except Empty:
        logger.error("ParallelModel forward timeout")
        raise RuntimeError("ParallelModel forward timeout")
    except Exception as e:
        logger.error(f"ParallelModel forward error: {e}")
        raise RuntimeError(f"ParallelModel forward error: {e}") from e
    # An exception object received as the reply is surfaced to the caller,
    # chained so the original error type stays visible in the traceback.
    if isinstance(res, Exception):
        logger.error(f"ParallelModel forward error: {res}")
        raise RuntimeError(f"ParallelModel forward error: {res}") from res
    return res
391407
392408 def __del__ (self ):
393409 # Send terminate signal to all workers
0 commit comments