Skip to content

Commit 7a192fb

Browse files
committed
refactor: make create_graph_builders an abstract class method
1 parent 8944d40 commit 7a192fb

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

optimum/neuron/models/inference/backend/modules/decoder/modeling_decoder.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -383,24 +383,26 @@ def _create_speculation_config(neuron_config: NxDNeuronConfig) -> NxDNeuronConfi
383383
spec_neuron_config.batch_size = neuron_config.tkg_batch_size
384384
return spec_neuron_config
385385

386-
@staticmethod
387-
def create_graph_builders(model_cls, config, neuron_config):
386+
@classmethod
387+
def create_graph_builders(cls, config, neuron_config):
388+
if cls._model_cls is None:
389+
raise SystemError(f"No underlying model class defined for {cls}.")
388390
graph_builders = {}
389391
ctx_neuron_config = NxDModelForCausalLM._create_context_encoding_config(neuron_config)
390392
graph_builders["context_encoding"] = NxDDecoderBuilder(
391393
config=config,
392394
neuron_config=ctx_neuron_config,
393395
max_tokens=ctx_neuron_config.max_context_length,
394396
active_tokens=ctx_neuron_config.max_context_length,
395-
model_cls=model_cls,
397+
model_cls=cls._model_cls,
396398
)
397399
tkg_neuron_config = NxDModelForCausalLM._create_token_generation_config(neuron_config)
398400
graph_builders["token_generation"] = NxDDecoderBuilder(
399401
config=config,
400402
neuron_config=tkg_neuron_config,
401403
max_tokens=tkg_neuron_config.sequence_length,
402404
active_tokens=1,
403-
model_cls=model_cls,
405+
model_cls=cls._model_cls,
404406
priority_model_idx=0, # to turn on weight layout optimization
405407
)
406408
if neuron_config.speculation_length > 0:
@@ -410,7 +412,7 @@ def create_graph_builders(model_cls, config, neuron_config):
410412
neuron_config=spec_neuron_config,
411413
max_tokens=spec_neuron_config.sequence_length,
412414
active_tokens=spec_neuron_config.speculation_length,
413-
model_cls=model_cls,
415+
model_cls=cls._model_cls,
414416
priority_model_idx=0, # to turn on weight layout optimization
415417
)
416418
return graph_builders
@@ -617,9 +619,7 @@ def _from_pretrained(
617619
traced_model = torch.jit.load(os.path.join(tmpdir, cls.COMPILED_MODEL_FILE_NAME))
618620
else:
619621
traced_model = torch.jit.load(os.path.join(model_id, cls.COMPILED_MODEL_FILE_NAME))
620-
graph_builders = NxDModelForCausalLM.create_graph_builders(
621-
cls._model_cls, config=config, neuron_config=neuron_config
622-
)
622+
graph_builders = NxDModelForCausalLM.create_graph_builders(config=config, neuron_config=neuron_config)
623623
model = cls(
624624
config=config,
625625
neuron_config=neuron_config,
@@ -647,7 +647,7 @@ def _export(
647647
force_download: bool | None = False,
648648
local_files_only: bool | None = False,
649649
trust_remote_code: bool | None = False,
650-
load_weights: bool = False,
650+
load_weights: bool | None = False,
651651
**kwargs,
652652
) -> "NeuronModelForCausalLM":
653653
if len(kwargs) > 0:
@@ -675,7 +675,6 @@ def _export(
675675
if hasattr(config, "head_dim") and config.head_dim is None:
676676
config.head_dim = config.hidden_size // config.num_attention_heads
677677
graph_builders = cls.create_graph_builders(
678-
model_cls=cls._model_cls,
679678
config=config,
680679
neuron_config=neuron_config,
681680
)

optimum/neuron/models/inference/backend/pretrained_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,15 @@ def get_compiler_args(cls, neuron_config) -> str | None:
130130
"""Gets the Neuron compiler arguments to use when compiling this model."""
131131
return None
132132

133+
@classmethod
134+
@abstractmethod
135+
def create_graph_builders(
136+
cls, config: PretrainedConfig, neuron_config: NxDNeuronConfig
137+
) -> dict[str, NxDGraphBuilder]:
138+
raise NotImplementedError(
139+
"The child class must provide a method to return the model graph builders dictionary."
140+
)
141+
133142
@staticmethod
134143
def compile(neuron_config, graph_builders: dict[str, NxDGraphBuilder], compiler_args: str, debug: bool = False):
135144
builder = get_builder(neuron_config, graph_builders, debug=debug, compiler_args=compiler_args)

0 commit comments

Comments
 (0)