
Commit 1dd84c7

refactor: sanitize tag usage

Parent: 71aded3

5 files changed: +25 additions, -60 deletions


optimum/neuron/models/inference/backend/graph_builder.py (1 addition, 2 deletions)

@@ -19,9 +19,8 @@
 class NxDGraphBuilder(ABC):
-    def __init__(self, tag: str, priority_model_idx: int):
+    def __init__(self, priority_model_idx: int):
         super().__init__()
-        self.tag = tag
         self.priority_model_idx = priority_model_idx

     @abstractmethod
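
With `tag` removed from the constructor, a builder no longer knows its own name; whoever registers it chooses the key. A minimal sketch of a subclass under the new signature (the subclass name and body are hypothetical, and the import path is inferred from the file location; only the constructor shape comes from this commit):

from optimum.neuron.models.inference.backend.graph_builder import NxDGraphBuilder


class ToyGraphBuilder(NxDGraphBuilder):
    """Hypothetical subclass, for illustration only."""

    def __init__(self, priority_model_idx: int = None):
        # The base constructor now takes only the priority index; no `tag`.
        super().__init__(priority_model_idx)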

optimum/neuron/models/inference/backend/modules/decoder/decoder_builder.py (1 addition, 2 deletions)

@@ -30,10 +30,9 @@ def __init__(
         max_tokens: int,
         active_tokens: int,
         model_cls,
-        tag="",
         priority_model_idx: int = None,
     ) -> None:
-        super().__init__(tag, priority_model_idx)
+        super().__init__(priority_model_idx)
         self.config = config
         self.neuron_config = neuron_config
         self.max_tokens = max_tokens
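
Any call site that still passes `tag=` will now fail with a TypeError. A hedged before/after sketch (the argument values are placeholders mirroring the token generation builder in the next file, not part of this one):

# Before this commit, the builder carried its own tag:
#   NxDDecoderBuilder(..., tag=TOKEN_GENERATION_MODEL_TAG, priority_model_idx=0)
# After, the tag is dropped and becomes the key under which the builder is registered:
builder = NxDDecoderBuilder(
    config=config,                # assumed PretrainedConfig
    neuron_config=neuron_config,  # assumed NxDNeuronConfig
    max_tokens=neuron_config.sequence_length,
    active_tokens=1,
    model_cls=model_cls,          # assumed model class
    priority_model_idx=0,
)
graph_builders = {"token_generation": builder}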

optimum/neuron/models/inference/backend/modules/decoder/modeling_decoder.py (15 additions, 48 deletions)

@@ -384,69 +384,36 @@ def _create_speculation_config(neuron_config: NxDNeuronConfig) -> NxDNeuronConfig:
         return spec_neuron_config

     @staticmethod
-    def _create_context_encoding_builder(model_cls, config, neuron_config):
+    def create_graph_builders(model_cls, config, neuron_config):
+        graph_builders = {}
         ctx_neuron_config = NxDModelForCausalLM._create_context_encoding_config(neuron_config)
-
-        return NxDDecoderBuilder(
+        graph_builders["context_encoding"] = NxDDecoderBuilder(
             config=config,
             neuron_config=ctx_neuron_config,
             max_tokens=ctx_neuron_config.max_context_length,
             active_tokens=ctx_neuron_config.max_context_length,
             model_cls=model_cls,
-            tag=CONTEXT_ENCODING_MODEL_TAG,
         )
-
-    @staticmethod
-    def _create_token_generation_builder(model_cls, config, neuron_config, enable_wlt_optimization: bool = True):
         tkg_neuron_config = NxDModelForCausalLM._create_token_generation_config(neuron_config)
-
-        return NxDDecoderBuilder(
+        graph_builders["token_generation"] = NxDDecoderBuilder(
             config=config,
             neuron_config=tkg_neuron_config,
             max_tokens=tkg_neuron_config.sequence_length,
             active_tokens=1,
             model_cls=model_cls,
-            tag=TOKEN_GENERATION_MODEL_TAG,
-            priority_model_idx=0 if enable_wlt_optimization else None,  # to turn on weight layout optimization
-        )
-
-    @staticmethod
-    def _create_speculation_builder(model_cls, config, neuron_config):
-        spec_neuron_config = NxDModelForCausalLM._create_speculation_config(neuron_config)
-
-        return NxDDecoderBuilder(
-            config=config,
-            neuron_config=spec_neuron_config,
-            max_tokens=spec_neuron_config.sequence_length,
-            active_tokens=spec_neuron_config.speculation_length,
-            model_cls=model_cls,
-            tag=SPECULATION_MODEL_TAG,
             priority_model_idx=0,  # to turn on weight layout optimization
         )
-
-    @staticmethod
-    def create_model_builders(model_cls, config, neuron_config):
-        model_builders = [
-            NxDModelForCausalLM._create_context_encoding_builder(
-                model_cls,
-                config,
-                neuron_config,
-            ),
-            NxDModelForCausalLM._create_token_generation_builder(
-                model_cls,
-                config,
-                neuron_config,
-            ),
-        ]
         if neuron_config.speculation_length > 0:
-            model_builders.append(
-                NxDModelForCausalLM._create_speculation_builder(
-                    model_cls,
-                    config,
-                    neuron_config,
-                )
+            spec_neuron_config = NxDModelForCausalLM._create_speculation_config(neuron_config)
+            graph_builders["speculation_model"] = NxDDecoderBuilder(
+                config=config,
+                neuron_config=spec_neuron_config,
+                max_tokens=spec_neuron_config.sequence_length,
+                active_tokens=spec_neuron_config.speculation_length,
+                model_cls=model_cls,
+                priority_model_idx=0,  # to turn on weight layout optimization
             )
-        return model_builders
+        return graph_builders

     def forward(
         self,

@@ -650,7 +617,7 @@ def _from_pretrained(
             traced_model = torch.jit.load(os.path.join(tmpdir, cls.COMPILED_MODEL_FILE_NAME))
         else:
             traced_model = torch.jit.load(os.path.join(model_id, cls.COMPILED_MODEL_FILE_NAME))
-        model_builders = NxDModelForCausalLM.create_model_builders(
+        model_builders = NxDModelForCausalLM.create_graph_builders(
             cls._model_cls, config=config, neuron_config=neuron_config
         )
         model = cls(

@@ -707,7 +674,7 @@ def _export(
         # Evaluate head_dim if it is defined but set to null (like in Mixtral for transformers 4.54+)
         if hasattr(config, "head_dim") and config.head_dim is None:
             config.head_dim = config.hidden_size // config.num_attention_heads
-        model_builders = cls.create_model_builders(
+        model_builders = cls.create_graph_builders(
             model_cls=cls._model_cls,
             config=config,
             neuron_config=neuron_config,
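
The three per-mode factories collapse into a single `create_graph_builders` that returns a dict keyed by graph name, and the `enable_wlt_optimization` toggle disappears: token generation now always sets `priority_model_idx=0`. A hedged usage sketch (only the method name and dict keys come from this diff; the surrounding objects are assumed):

# Sketch: `model_cls`, `config` and `neuron_config` are assumed to exist.
graph_builders = NxDModelForCausalLM.create_graph_builders(
    model_cls=model_cls,
    config=config,
    neuron_config=neuron_config,
)
assert "context_encoding" in graph_builders
assert "token_generation" in graph_builders
# Present only when speculative decoding is enabled:
if neuron_config.speculation_length > 0:
    assert "speculation_model" in graph_builders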

optimum/neuron/models/inference/backend/modules/kvcache/kv_cache_manager.py (3 additions, 3 deletions)

@@ -93,7 +93,7 @@ def get_cache(self, seq_len: int, skip_slice=False, **kwargs):
         """
         Return network (all layers)'s previously cached K and V, up to seq_len.

-        :param seq_len: sequence length (or bucket size from auto-bucketing e.g. 128, 512, 1024 etc.)
+        :param seq_len: sequence length
         :return: list of tuple of (K, V)
         """
         slice_index, gather_index = None, None

@@ -129,9 +129,9 @@ def update_cache(
         :param scatter_index: tensor representing index to update
         :param is_for_context_encoding: bool
         :param seq_ids: tensor of size (batch_sz)
-        :param position_ids: tensor of size (batch_sz, bucket_sz)
+        :param position_ids: tensor of size (batch_sz, seq_len)
         :param new_key_values: list of tuple, the latest kv obtained at the end of the network from forward pass
-        :param seq_len: sequence length (or bucket size from auto-bucketing e.g. 128, 512, 1024 etc.)
+        :param seq_len: sequence length
         :param scatter_index: tensor representing index to update
         :param active_mask: tensor representing index to update
         :param kvcache_buffer: if passed key states are updates to this buffer.
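
For orientation, a hedged sketch of how these two methods might be called, using only the parameters documented above (the manager instance, the exact required arguments, and all tensor shapes are assumptions, not taken from this commit):

import torch

# Assumed: `manager` is an initialized KV cache manager and `new_key_values`
# is the list of (K, V) tuples produced by the model's forward pass.
past = manager.get_cache(seq_len=128)   # list of (K, V), one tuple per layer
k0, v0 = past[0]                        # cached keys/values of the first layer

manager.update_cache(
    is_for_context_encoding=False,
    seq_ids=torch.tensor([0]),          # (batch_sz,)
    position_ids=torch.tensor([[42]]),  # (batch_sz, seq_len)
    new_key_values=new_key_values,
    seq_len=128,
    scatter_index=scatter_index,        # tensor of indices to update
)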

optimum/neuron/models/inference/backend/pretrained_model.py (5 additions, 5 deletions)

@@ -49,7 +49,7 @@ def get_shards_path(dest_path):

 def get_builder(
     neuron_config: NxDNeuronConfig,
-    model_wrappers: list[NxDGraphBuilder],
+    model_wrappers: dict[str, NxDGraphBuilder],
     debug: bool = False,
     checkpoint_loader=None,
     compiler_args: str = None,

@@ -86,9 +86,9 @@
         logical_nc_config=neuron_config.logical_nc_config,
         weights_to_skip_layout_optimization=neuron_config.weights_to_skip_layout_optimization,
     )
-    for model in model_wrappers:
+    for tag, model in model_wrappers.items():
         builder.add(
-            key=model.tag,
+            key=tag,
             model_instance=model.get_model_instance(),
             example_inputs=model.input_generator(),
             compiler_args=compiler_args,

@@ -109,7 +109,7 @@ def __init__(
         config: PretrainedConfig,
         neuron_config: NxDNeuronConfig,
         traced_model: torch.jit.ScriptModule,
-        model_wrappers: list[NxDGraphBuilder],
+        model_wrappers: dict[str, NxDGraphBuilder],
     ):
         self.config = copy.deepcopy(config)
         self.neuron_config = copy.deepcopy(neuron_config)

@@ -131,7 +131,7 @@ def get_compiler_args(cls, neuron_config) -> str | None:
         return None

     @staticmethod
-    def compile(neuron_config, model_wrappers: list[NxDGraphBuilder], compiler_args: str, debug: bool = False):
+    def compile(neuron_config, model_wrappers: dict[str, NxDGraphBuilder], compiler_args: str, debug: bool = False):
         builder = get_builder(neuron_config, model_wrappers, debug=debug, compiler_args=compiler_args)
         return builder.trace(initialize_model_weights=False)
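
End to end, the tag now lives in exactly one place: the dictionary key. Since Python dicts preserve insertion order, graphs are still registered with the builder in the same order the old list produced. A hedged sketch of the wiring (the class that owns `compile` is not shown in this diff, so its name here is assumed):

graph_builders = NxDModelForCausalLM.create_graph_builders(
    model_cls=model_cls, config=config, neuron_config=neuron_config
)
# `compile` is the @staticmethod patched above; `NxDPreTrainedModel` is assumed.
traced_model = NxDPreTrainedModel.compile(
    neuron_config,
    model_wrappers=graph_builders,
    compiler_args=compiler_args,
)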
