Skip to content

Commit 8944d40

Browse files
committed
refactor: avoid using model_wrapper improperly
1 parent 1dd84c7 commit 8944d40

File tree

2 files changed: +19 additions, -19 deletions

optimum/neuron/models/inference/backend/modules/decoder/modeling_decoder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -338,10 +338,10 @@ def __init__(
338338
config: PretrainedConfig,
339339
neuron_config: NxDNeuronConfig,
340340
traced_model: torch.jit.ScriptModule,
341-
model_wrappers: list[NxDGraphBuilder],
341+
graph_builders: list[NxDGraphBuilder],
342342
):
343343
super().__init__(
344-
config=config, neuron_config=neuron_config, traced_model=traced_model, model_wrappers=model_wrappers
344+
config=config, neuron_config=neuron_config, traced_model=traced_model, graph_builders=graph_builders
345345
)
346346
ctx_neuron_config = NxDModelForCausalLM._create_context_encoding_config(neuron_config)
347347
self.context_encoding_model = NxDDecoderWrapper(
@@ -617,14 +617,14 @@ def _from_pretrained(
617617
traced_model = torch.jit.load(os.path.join(tmpdir, cls.COMPILED_MODEL_FILE_NAME))
618618
else:
619619
traced_model = torch.jit.load(os.path.join(model_id, cls.COMPILED_MODEL_FILE_NAME))
620-
model_builders = NxDModelForCausalLM.create_graph_builders(
620+
graph_builders = NxDModelForCausalLM.create_graph_builders(
621621
cls._model_cls, config=config, neuron_config=neuron_config
622622
)
623623
model = cls(
624624
config=config,
625625
neuron_config=neuron_config,
626626
traced_model=traced_model,
627-
model_wrappers=model_builders,
627+
graph_builders=graph_builders,
628628
)
629629
model.load_weights(
630630
model_id,
@@ -674,7 +674,7 @@ def _export(
674674
# Evaluate head_dim if it is defined but set to null (like in Mixtral for transformers 4.54+)
675675
if hasattr(config, "head_dim") and config.head_dim is None:
676676
config.head_dim = config.hidden_size // config.num_attention_heads
677-
model_builders = cls.create_graph_builders(
677+
graph_builders = cls.create_graph_builders(
678678
model_cls=cls._model_cls,
679679
config=config,
680680
neuron_config=neuron_config,
@@ -689,14 +689,14 @@ def _export(
689689
with hub_neuronx_cache(entry=cache_entry):
690690
traced_model = NxDPreTrainedModel.compile(
691691
neuron_config=neuron_config,
692-
model_wrappers=model_builders,
692+
graph_builders=graph_builders,
693693
compiler_args=cls.get_compiler_args(neuron_config),
694694
)
695695
model = cls(
696696
config=config,
697697
neuron_config=neuron_config,
698698
traced_model=traced_model,
699-
model_wrappers=model_builders,
699+
graph_builders=graph_builders,
700700
)
701701
if load_weights:
702702
model.load_weights(

optimum/neuron/models/inference/backend/pretrained_model.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def get_shards_path(dest_path):
4949

5050
def get_builder(
5151
neuron_config: NxDNeuronConfig,
52-
model_wrappers: dict[str, NxDGraphBuilder],
52+
graph_builders: dict[str, NxDGraphBuilder],
5353
debug: bool = False,
5454
checkpoint_loader=None,
5555
compiler_args: str = None,
@@ -63,7 +63,7 @@ def get_builder(
6363
6464
Args:
6565
neuron_config (NxDNeuronConfig): The Neuron configuration.
66-
model_wrappers (list[NxDGraphBuilder]): The model graphs to be added to the builder.
66+
graph_builders (dict[str, NxDGraphBuilder]): The model graphs to be added to the builder, keyed by tag.
6767
debug (bool): Whether to enable debug mode.
6868
checkpoint_loader (callable): A function to load the model's state dictionary and weights.
6969
compiler_args (str): Compiler arguments to be passed to the builder.
@@ -86,13 +86,13 @@ def get_builder(
8686
logical_nc_config=neuron_config.logical_nc_config,
8787
weights_to_skip_layout_optimization=neuron_config.weights_to_skip_layout_optimization,
8888
)
89-
for tag, model in model_wrappers.items():
89+
for tag, graph_builder in graph_builders.items():
9090
builder.add(
9191
key=tag,
92-
model_instance=model.get_model_instance(),
93-
example_inputs=model.input_generator(),
92+
model_instance=graph_builder.get_model_instance(),
93+
example_inputs=graph_builder.input_generator(),
9494
compiler_args=compiler_args,
95-
priority_model_idx=model.priority_model_idx,
95+
priority_model_idx=graph_builder.priority_model_idx,
9696
)
9797
return builder
9898

@@ -109,14 +109,14 @@ def __init__(
109109
config: PretrainedConfig,
110110
neuron_config: NxDNeuronConfig,
111111
traced_model: torch.jit.ScriptModule,
112-
model_wrappers: dict[str, NxDGraphBuilder],
112+
graph_builders: dict[str, NxDGraphBuilder],
113113
):
114114
self.config = copy.deepcopy(config)
115115
self.neuron_config = copy.deepcopy(neuron_config)
116116
# Override torch_dtype in config as it is used by the neuronx_distributed code to cast weights to the correct type
117117
self.config.torch_dtype = self.neuron_config.torch_dtype
118118
self._traced_model = traced_model
119-
self.model_wrappers = model_wrappers # Required for loading weights
119+
self.graph_builders = graph_builders # Required for loading weights
120120

121121
# NxDPretrainedModel abstract API
122122
@abstractmethod
@@ -131,8 +131,8 @@ def get_compiler_args(cls, neuron_config) -> str | None:
131131
return None
132132

133133
@staticmethod
134-
def compile(neuron_config, model_wrappers: dict[str, NxDGraphBuilder], compiler_args: str, debug: bool = False):
135-
builder = get_builder(neuron_config, model_wrappers, debug=debug, compiler_args=compiler_args)
134+
def compile(neuron_config, graph_builders: dict[str, NxDGraphBuilder], compiler_args: str, debug: bool = False):
135+
builder = get_builder(neuron_config, graph_builders, debug=debug, compiler_args=compiler_args)
136136
return builder.trace(initialize_model_weights=False)
137137

138138
def save(self, dest_path, weight_path: str | None = None):
@@ -153,7 +153,7 @@ def shard_checkpoint(self, src_path, dest_path, debug: bool = False):
153153
checkpoint_loader = partial(self.checkpoint_loader_fn, src_path, self.config, self.neuron_config)
154154
sharder = get_builder(
155155
self.neuron_config,
156-
self.model_wrappers,
156+
self.graph_builders,
157157
debug=debug,
158158
checkpoint_loader=checkpoint_loader,
159159
compiler_args=self.get_compiler_args(self.neuron_config),
@@ -191,7 +191,7 @@ def get_shard_name(rank):
191191
checkpoint_loader = partial(self.checkpoint_loader_fn, weights_path, self.config, self.neuron_config)
192192
sharder = get_builder(
193193
self.neuron_config,
194-
self.model_wrappers,
194+
self.graph_builders,
195195
debug=False,
196196
checkpoint_loader=checkpoint_loader,
197197
compiler_args=self.get_compiler_args(self.neuron_config),

0 commit comments

Comments
 (0)