 import torch._inductor.config
 import torch.nn as nn

+from torch.distributed import launcher
+
 from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.elastic.multiprocessing.errors import record
+from torch.distributed.elastic.utils.distributed import get_free_port
+from torch.distributed.launcher.api import elastic_launch

 from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama

@@ -58,8 +63,8 @@ class BuilderArgs:
     distributed: bool = False
     num_gpus: int = 1
     num_nodes: int = 1
-    pp_dim: int = 1
-    tp_dim: int = 1
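+    # pipeline-parallel (pp) and tensor-parallel (tp) degrees for distributed inference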
+    pp: int = 1
+    tp: int = 1
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -164,8 +169,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         distributed = getattr(args, "distributed", False)
         num_gpus = getattr(args, "num_gpus", 1)
         num_nodes = getattr(args, "num_nodes", 1)
-        pp_dim = getattr(args, "pp_dim", 1)
-        tp_dim = getattr(args, "tp_dim", 1)
+        pp = getattr(args, "pp", 1)
+        tp = getattr(args, "tp", 1)
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -182,8 +187,8 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             distributed=distributed,
             num_gpus=num_gpus,
             num_nodes=num_nodes,
-            pp_dim=pp_dim,
-            tp_dim=tp_dim,
+            pp=pp,
+            tp=tp,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
@@ -492,19 +497,70 @@ def _maybe_parellelize_model(


 def _load_model(builder_args: BuilderArgs) -> Model:
-    world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
+    # world_mesh, parallel_dims = _maybe_init_distributed(builder_args)
     if builder_args.gguf_path:
         model = _load_model_gguf(builder_args)
-    elif builder_args.use_distributed:
-        model = _init_model_on_meta_device(builder_args)
+    # elif builder_args.use_distributed:
+    #     model = _init_model_on_meta_device(builder_args)
     else:
         model = _load_model_default(builder_args)
-    model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims)
+    # model = _maybe_parellelize_model(model, builder_args, world_mesh, parallel_dims)

     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()


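+# run_main is the per-worker entrypoint handed to elastic_launch below; @record
+# captures worker exceptions so the elastic agent can report them.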
+@record
+def run_main(train_file_path):
+    # Add the directory containing the dist_run file to sys.path
+    print(f"******* {train_file_path=}")
+    sys.path.insert(0, os.path.dirname(os.path.abspath(train_file_path)))
+
+    # LOCAL_RANK, RANK and WORLD_SIZE are set for each worker by the elastic agent,
+    # so they do not need to be set manually here.
+    local_rank = int(os.environ["LOCAL_RANK"])
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])  # num_gpus * num_nodes
+
+    # Execute the dist_run file in this worker process
+    with open(train_file_path, "rb") as file:
+        exec(compile(file.read(), str(train_file_path), "exec"))
+
+
+def _launch_distributed_inference(builder_args: BuilderArgs) -> None:
+    # create programmatic elastic launch
+    print("Launching distributed inference ...")
+
+    num_processes_per_node = 4  # builder_args.num_gpus + 1
+
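+    # Single-node elastic config: spawn nproc_per_node worker processes, rendezvous
+    # over c10d at localhost:29401, and do not restart failed workers.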
+    lc = launcher.LaunchConfig(
+        min_nodes=1,
+        max_nodes=1,
+        nproc_per_node=num_processes_per_node,
+        # run_id=str(uuid.uuid4()),
+        rdzv_backend="c10d",
+        rdzv_endpoint="localhost:29401",
+        max_restarts=0,
+        monitor_interval=1,
+    )
+
+    train_file_path = Path(__file__).parent / "distributed" / "dist_run.py"
+
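+    # elastic_launch runs run_main(train_file_path) once per worker process and sets
+    # LOCAL_RANK, RANK and WORLD_SIZE in each worker's environment.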
+    elastic_launch(
+        config=lc,
+        entrypoint=run_main,
+    )(train_file_path)
+    print(
+        f"Done launching distributed inference on {num_processes_per_node} GPUs."
+    )
+
+
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
@@ -513,6 +569,10 @@ def _initialize_model(
     support_tensor_subclass: bool = True,
 ) -> Model:
     print("Loading model...")
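+    # Distributed path: hand the request off to the elastic launcher instead of
+    # loading the model in this process.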
+    if builder_args.distributed:
+        # we part ways here with torchchat cli and move into dist inference
+        _launch_distributed_inference(builder_args)
+        return None

     if builder_args.gguf_path and (builder_args.dso_path or builder_args.pte_path):
         print("Setting gguf_kwargs for generate.")