@@ -510,23 +510,29 @@ def _load_model(builder_args: BuilderArgs) -> Model:
510510 return model .eval ()
511511
512512
513- @record
514- def run_main (local_rank ):
515- # Add the directory containing the train file to sys.path
516- train_file_path = Path (__file__ ).parent .parent .parent / "dist_run.py"
517- print (f"******* { train_file_path = } " )
518- sys .path .insert (0 , os .path .dirname (os .path .abspath (train_file_path )))
513+ import importlib .util
514+ import subprocess
515+
519516
520- # Set environment variables for distributed training
521- os .environ ["LOCAL_RANK" ] = str (local_rank )
522- os .environ ["RANK" ] = str (
523- local_rank # + kwargs.get("node_rank", 0) * num_processes_per_node
517+ def run_script (script_path , * args ):
518+ # Construct the command to run the script
519+ cmd = [sys .executable , script_path ] + list (args )
520+
521+ # Run the script as a subprocess
522+ process = subprocess .Popen (
523+ cmd , stdout = subprocess .PIPE , stderr = subprocess .PIPE , universal_newlines = True
524524 )
525- os .environ ["WORLD_SIZE" ] = str (4 * 1 ) # num_nodes)
526525
527- # Execute the train file
528- with open (train_file_path , "rb" ) as file :
529- exec (compile (file .read (), train_file_path , "exec" ))
526+ # Stream the output in real-time
527+ for line in process .stdout :
528+ print (line , end = "" )
529+ for line in process .stderr :
530+ print (line , end = "" , file = sys .stderr )
531+
532+ # Wait for the process to complete and get the return code
533+ return_code = process .wait ()
534+ if return_code != 0 :
535+ raise subprocess .CalledProcessError (return_code , cmd )
530536
531537
532538def _launch_distributed_inference (builder_args : BuilderArgs ) -> None :
@@ -546,12 +552,20 @@ def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
546552 monitor_interval = 1 ,
547553 )
548554
549- train_file_path = Path (__file__ ).parent / "distributed" / "dist_run.py"
555+ train_file_path = Path (__file__ ).parent .parent .parent / "dist_run.py"
556+ print (f"train_file_path: { train_file_path } " )
557+ # import argparse
558+
559+ # parser2 = argparse.ArgumentParser()
560+
561+ # args = parser2.parse_args()
562+ args = []
563+ print (f"args: { args } " )
550564
551565 elastic_launch (
552566 config = lc ,
553- entrypoint = run_main ,
554- )(train_file_path )
567+ entrypoint = run_script ,
568+ )(train_file_path , * args )
555569 print (
556570 f"Done launching distributed inference on **4 ** { builder_args .num_gpus } GPUs."
557571 )
0 commit comments