@@ -274,8 +274,17 @@ def train(
274274 # Run initializers if configured
275275 if initializer :
276276 logger .debug ("Running initializers" )
277- self ._run_initializers (trainjob_name , initializer , workdir , network_id )
278- logger .debug ("Initializers completed successfully" )
277+ try :
278+ self ._run_initializers (trainjob_name , initializer , workdir , network_id )
279+ logger .debug ("Initializers completed successfully" )
280+ except Exception as e :
281+ # Clean up network if initializers fail
282+ logger .error (f"Initializer failed, cleaning up network: { e } " )
283+ from contextlib import suppress
284+
285+ with suppress (Exception ):
286+ self ._adapter .delete_network (network_id )
287+ raise
279288
280289 # Generate training script code (inline, not written to disk)
281290 training_script_code = container_utils .get_training_script_code (trainer )
@@ -493,7 +502,7 @@ def _run_initializers(
493502 RuntimeError: If initializer fails to complete successfully.
494503 """
495504 # Get initializer image
496- init_image = container_utils .get_initializer_image ()
505+ init_image = container_utils .get_initializer_image (self . cfg )
497506
498507 # Pull initializer image if needed
499508 container_utils .maybe_pull_image (self ._adapter , init_image , self .cfg .pull_policy )
@@ -586,32 +595,40 @@ def _run_single_initializer(
586595
587596 # Wait for the initializer to complete
588597 try :
589- import time
598+ # Use the wait API for efficient waiting
599+ exit_code = self ._adapter .wait_for_container (
600+ container_id , timeout = self .cfg .initializer_timeout
601+ )
590602
591- timeout = 600 # 10 minutes timeout for initialization
592- polling_interval = 2
593- elapsed = 0
603+ if exit_code == 0 :
604+ logger .debug (f"{ init_type } initializer completed successfully" )
605+ # Clean up the successful container
606+ from contextlib import suppress
607+
608+ with suppress (Exception ):
609+ self ._adapter .remove_container (container_id , force = True )
610+ return
611+ else :
612+ # Get logs for debugging
613+ logs = list (self ._adapter .container_logs (container_id , follow = False ))
614+ error_msg = (
615+ f"{ init_type } initializer failed with exit code { exit_code } . "
616+ f"Logs: { ' ' .join (logs [- 10 :]) if logs else 'No logs available' } "
617+ )
618+ raise RuntimeError (error_msg )
594619
595- while elapsed < timeout :
596- status , exit_code = self ._adapter .container_status (container_id )
620+ except TimeoutError :
621+ logger .error (
622+ f"{ init_type } initializer did not complete within "
623+ f"{ self .cfg .initializer_timeout } seconds"
624+ )
625+ # Clean up the timed-out container
626+ from contextlib import suppress
597627
598- if status == "exited" :
599- if exit_code == 0 :
600- logger .debug (f"{ init_type } initializer completed successfully" )
601- return
602- else :
603- # Get logs for debugging
604- logs = list (self ._adapter .container_logs (container_id , follow = False ))
605- error_msg = (
606- f"{ init_type } initializer failed with exit code { exit_code } . "
607- f"Logs: { ' ' .join (logs [- 10 :]) if logs else 'No logs available' } "
608- )
609- raise RuntimeError (error_msg )
610-
611- time .sleep (polling_interval )
612- elapsed += polling_interval
613-
614- raise TimeoutError (f"{ init_type } initializer did not complete within { timeout } seconds" )
628+ with suppress (Exception ):
629+ self ._adapter .stop_container (container_id , timeout = 5 )
630+ self ._adapter .remove_container (container_id , force = True )
631+ raise
615632
616633 except Exception as e :
617634 logger .error (f"Error running { init_type } initializer: { e } " )
0 commit comments