@@ -49,7 +49,7 @@ def get_shards_path(dest_path):
4949
5050def get_builder (
5151 neuron_config : NxDNeuronConfig ,
52- model_wrappers : dict [str , NxDGraphBuilder ],
52+ graph_builders : dict [str , NxDGraphBuilder ],
5353 debug : bool = False ,
5454 checkpoint_loader = None ,
5555 compiler_args : str = None ,
@@ -63,7 +63,7 @@ def get_builder(
6363
6464 Args:
6565 neuron_config (NxDNeuronConfig): The Neuron configuration.
66- model_wrappers (list[NxDGraphBuilder]): The model graphs to be added to the builder.
66+ graph_builders (list[NxDGraphBuilder]): The model graphs to be added to the builder.
6767 debug (bool): Whether to enable debug mode.
6868 checkpoint_loader (callable): A function to load the model's state dictionary and weights.
6969 compiler_args (str): Compiler arguments to be passed to the builder.
@@ -86,13 +86,13 @@ def get_builder(
8686 logical_nc_config = neuron_config .logical_nc_config ,
8787 weights_to_skip_layout_optimization = neuron_config .weights_to_skip_layout_optimization ,
8888 )
89- for tag , model in model_wrappers .items ():
89+ for tag , graph_builder in graph_builders .items ():
9090 builder .add (
9191 key = tag ,
92- model_instance = model .get_model_instance (),
93- example_inputs = model .input_generator (),
92+ model_instance = graph_builder .get_model_instance (),
93+ example_inputs = graph_builder .input_generator (),
9494 compiler_args = compiler_args ,
95- priority_model_idx = model .priority_model_idx ,
95+ priority_model_idx = graph_builder .priority_model_idx ,
9696 )
9797 return builder
9898
@@ -109,14 +109,14 @@ def __init__(
109109 config : PretrainedConfig ,
110110 neuron_config : NxDNeuronConfig ,
111111 traced_model : torch .jit .ScriptModule ,
112- model_wrappers : dict [str , NxDGraphBuilder ],
112+ graph_builders : dict [str , NxDGraphBuilder ],
113113 ):
114114 self .config = copy .deepcopy (config )
115115 self .neuron_config = copy .deepcopy (neuron_config )
116116 # Override torch_dtype in config as it is used by the neuronx_distributed code to cast weights to the correct type
117117 self .config .torch_dtype = self .neuron_config .torch_dtype
118118 self ._traced_model = traced_model
119- self .model_wrappers = model_wrappers # Required for loading weights
119+ self .graph_builders = graph_builders # Required for loading weights
120120
121121 # NxDPretrainedModel abstract API
122122 @abstractmethod
@@ -131,8 +131,8 @@ def get_compiler_args(cls, neuron_config) -> str | None:
131131 return None
132132
133133 @staticmethod
134- def compile (neuron_config , model_wrappers : dict [str , NxDGraphBuilder ], compiler_args : str , debug : bool = False ):
135- builder = get_builder (neuron_config , model_wrappers , debug = debug , compiler_args = compiler_args )
134+ def compile (neuron_config , graph_builders : dict [str , NxDGraphBuilder ], compiler_args : str , debug : bool = False ):
135+ builder = get_builder (neuron_config , graph_builders , debug = debug , compiler_args = compiler_args )
136136 return builder .trace (initialize_model_weights = False )
137137
138138 def save (self , dest_path , weight_path : str | None = None ):
@@ -153,7 +153,7 @@ def shard_checkpoint(self, src_path, dest_path, debug: bool = False):
153153 checkpoint_loader = partial (self .checkpoint_loader_fn , src_path , self .config , self .neuron_config )
154154 sharder = get_builder (
155155 self .neuron_config ,
156- self .model_wrappers ,
156+ self .graph_builders ,
157157 debug = debug ,
158158 checkpoint_loader = checkpoint_loader ,
159159 compiler_args = self .get_compiler_args (self .neuron_config ),
@@ -191,7 +191,7 @@ def get_shard_name(rank):
191191 checkpoint_loader = partial (self .checkpoint_loader_fn , weights_path , self .config , self .neuron_config )
192192 sharder = get_builder (
193193 self .neuron_config ,
194- self .model_wrappers ,
194+ self .graph_builders ,
195195 debug = False ,
196196 checkpoint_loader = checkpoint_loader ,
197197 compiler_args = self .get_compiler_args (self .neuron_config ),
0 commit comments