 import torch._inductor.config
 import torch.nn as nn
 
-from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
-
-from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
-
 from torch.distributed.device_mesh import DeviceMesh
 
-from torchtune.models.convert_weights import meta_to_tune
-
-from torchtune.training import set_default_dtype
+from torchchat.distributed import launch_distributed, ParallelDims, parallelize_llama
 
 from torchchat.model import Model, ModelArgs, ModelType
 
-from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
-
 from torchchat.model_config.model_config import resolve_model_config
 from torchchat.utils.build_utils import (
     device_sync,
@@ -40,6 +32,14 @@
 from torchchat.utils.measure_time import measure_time
 from torchchat.utils.quantize import quantize_model
 
+from torchtune.models.convert_weights import meta_to_tune
+
+from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
+
+from torchtune.models.llama3_2_vision._convert_weights import llama3_vision_meta_to_tune
+
+from torchtune.training import set_default_dtype
+
 
 @dataclass
 class BuilderArgs:
@@ -55,7 +55,10 @@ class BuilderArgs:
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
-    use_distributed: bool = False
+    distributed: bool = False
+    num_gpus: int = 1
+    num_nodes: int = 1
+    pp_dim: int = 1
     is_chat_model: bool = False
     prefill_possible: bool = False
     dynamic_shapes: bool = False
@@ -156,7 +159,11 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             dtype = torch.float16
         else:
             dtype = name_to_dtype(args.dtype, args.device)
-
+        # distributed args
+        distributed = getattr(args, "distributed", False)
+        num_gpus = getattr(args, "num_gpus", 1)
+        num_nodes = getattr(args, "num_nodes", 1)
+        pp_dim = getattr(args, "pp_dim", 1)
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -170,7 +177,10 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             device=args.device,
             precision=dtype,
             setup_caches=(output_dso_path or output_pte_path),
-            use_distributed=args.distributed,
+            distributed=distributed,
+            num_gpus=num_gpus,
+            num_nodes=num_nodes,
+            pp_dim=pp_dim,
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
@@ -400,10 +410,10 @@ def _load_model_default(builder_args: BuilderArgs) -> Model:
         # does not host any actual values, need to reinitialize them in the actual
         # device. Only do those buffer initialization, without initializing the entire
         # model.
-        decoder_config = model.config.transformer_args['decoder']
-        head_dim = decoder_config['embed_dim'] // decoder_config['num_heads']
-        max_seq_len = decoder_config['max_seq_len']
-        rope_base = decoder_config['rope_base']
+        decoder_config = model.config.transformer_args["decoder"]
+        head_dim = decoder_config["embed_dim"] // decoder_config["num_heads"]
+        max_seq_len = decoder_config["max_seq_len"]
+        rope_base = decoder_config["rope_base"]
         for submodule in model.modules():
             if isinstance(submodule, Llama3ScaledRoPE):
                 submodule.__init__(head_dim, max_seq_len, rope_base)
@@ -491,6 +501,7 @@ def _load_model(builder_args: BuilderArgs) -> Model:
     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
 
+
 def _initialize_model(
     builder_args: BuilderArgs,
     quantize,
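Note on the new fields in BuilderArgs.from_args: the distributed options are read with getattr(args, ..., default), so the method keeps working even when the argparse namespace does not define the corresponding attributes. A minimal, self-contained sketch of that fallback behavior (the Namespace contents below are hypothetical, not torchchat's actual CLI parser):

import argparse

# Hypothetical namespace: only part of the new distributed options is set.
args = argparse.Namespace(distributed=True, num_gpus=4)

# Same getattr-with-default pattern as in BuilderArgs.from_args above;
# attributes missing from the namespace fall back to single-process defaults.
distributed = getattr(args, "distributed", False)  # True
num_gpus = getattr(args, "num_gpus", 1)            # 4
num_nodes = getattr(args, "num_nodes", 1)          # 1 (not set on args)
pp_dim = getattr(args, "pp_dim", 1)                # 1 (not set on args)

print(distributed, num_gpus, num_nodes, pp_dim)    # True 4 1 1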