
Commit 08bc221

refactor(nxd): split NxDModelWrapper
1 parent 196afc1 · commit 08bc221

6 files changed: +317 -251 lines changed
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from abc import ABC, abstractmethod
+
+import torch
+from neuronx_distributed.trace.model_builder import BaseModelInstance
+from torch_neuronx import BucketModelConfig
+
+
+class NxDGraphBuilder(ABC):
+    def __init__(self, tag: str, priority_model_idx: int):
+        super().__init__()
+        self.tag = tag
+        self.priority_model_idx = priority_model_idx
+
+    @abstractmethod
+    def input_generator(self) -> list[torch.Tensor]:
+        """Return the list of the model input tensors
+
+        Used at compilation time only when tracing the model.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_model_instance(self) -> BaseModelInstance:
+        """Return the underlying ModelInstance
+
+        Used at compilation time only when tracing the model.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def get_bucket_config(self) -> BucketModelConfig | None:
+        """Return the bucket configuration
+
+        Used at compilation time only when tracing the model.
+        """
+        raise NotImplementedError
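For orientation, a concrete graph builder only has to supply example inputs, a model instance, and an optional bucket configuration. The following is a minimal, hypothetical sketch: ToyModule, ToyModelInstance and ToyGraphBuilder are illustrative names, not part of this commit, and the BaseModelInstance subclassing pattern simply mirrors the DecoderModelInstance added later in this commit.

# Hedged sketch only: ToyModule, ToyModelInstance and ToyGraphBuilder are
# illustrative names, not part of this commit. NxDGraphBuilder is the abstract
# class added above (its import path is not shown in this diff).
import torch
from neuronx_distributed.trace.model_builder import BaseModelInstance


class ToyModule(torch.nn.Module):
    # Hypothetical runtime module: a single linear projection.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)


class ToyModelInstance(BaseModelInstance):
    # Mirrors the load_module()/get() pattern of the DecoderModelInstance added below.
    def __init__(self):
        self.module = None
        self.input_output_aliases = {}

    def initialize_process_group(self, world_size):
        # Nothing to initialize for this toy module.
        pass

    def load_module(self):
        self.module = ToyModule().eval()

    def get(self, bucket_rank, **kwargs):
        # Single graph, so bucket_rank is ignored.
        return self.module, self.input_output_aliases


class ToyGraphBuilder(NxDGraphBuilder):
    def input_generator(self):
        # One tuple of example inputs per traced graph (a single bucket here),
        # following the pattern of NxDDecoderBuilder below.
        return [(torch.zeros((1, 16), dtype=torch.float32),)]

    def get_model_instance(self):
        return ToyModelInstance()

    def get_bucket_config(self):
        # A single bucket needs no bucketing machinery.
        return None

Such a builder would be constructed with a tag and a priority index, e.g. ToyGraphBuilder(tag="toy", priority_model_idx=0), and queried only at compilation time when tracing the model.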

optimum/neuron/models/inference/backend/model_wrapper.py

Lines changed: 1 addition & 36 deletions
@@ -12,43 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from abc import abstractmethod
-
 import torch
-from neuronx_distributed.trace.model_builder import BaseModelInstance
-from torch_neuronx import BucketModelConfig


 class NxDModelWrapper(torch.nn.Module):
-    def __init__(self, tag: str, priority_model_idx: int):
-        super().__init__()
-        self.tag = tag
-        self.priority_model_idx = priority_model_idx
-
-    @abstractmethod
-    def input_generator(self) -> list[torch.Tensor]:
-        """Return the list of the model input tensors
-
-        Used at compilation time only when tracing the model.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_model_instance(self) -> BaseModelInstance:
-        """Return the underlying ModelInstance
-
-        Used at compilation time only when tracing the model.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_bucket_config(self) -> BucketModelConfig | None:
-        """Return the bucket configuration
-
-        Used at compilation time only when tracing the model.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def forward(self, *args) -> list[torch.Tensor]:
-        raise NotImplementedError
+    pass
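After this split, NxDModelWrapper keeps only the runtime role of a plain torch.nn.Module, while the compile-time hooks live on NxDGraphBuilder. A hedged sketch of what a runtime wrapper subclass might now look like (MyWrapper and traced_module are illustrative names, not part of this commit):

# Hedged sketch: a runtime-only wrapper around an already traced module.
# MyWrapper is an illustrative name, not part of this commit.
import torch


class MyWrapper(NxDModelWrapper):  # NxDModelWrapper from model_wrapper.py above
    def __init__(self, traced_module: torch.nn.Module):
        super().__init__()
        self.traced_module = traced_module

    def forward(self, *args) -> list[torch.Tensor]:
        # Runtime execution only; tracing concerns now live in NxDGraphBuilder.
        return self.traced_module(*args)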
Lines changed: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from neuronx_distributed.trace.model_builder import BaseModelInstance
+from torch_neuronx import BucketModelConfig
+from transformers import PretrainedConfig
+
+from ...config import NxDNeuronConfig
+from ...graph_builder import NxDGraphBuilder
+from ..autobucketing import (
+    get_context_encoder_bk,
+    get_generation_model_bk,
+)
+from ..generation.sampling import prepare_sampling_params
+
+
+CONTEXT_ENCODING_MODEL_TAG = "context_encoding_model"
+TOKEN_GENERATION_MODEL_TAG = "token_generation_model"
+SPECULATION_MODEL_TAG = "speculation_model"
+
+
+def get_bucket_model_config_from_tag(
+    tag, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: list[int]
+):
+    bucket_degree = len(buckets)
+    if bucket_degree == 1:
+        return None
+
+    pad_token = config.pad_token_id
+
+    # NOTE: KV Cache preprocessing is done within the model and not the
+    # shared buffer preprocessor due to lack of support of non-contiguous
+    # slicing of nrt tensors via the NRT API.
+    if tag == CONTEXT_ENCODING_MODEL_TAG:
+        return BucketModelConfig(
+            bucket_kernel=get_context_encoder_bk,
+            bucket_kernel_constant_args=(
+                torch.tensor(buckets),
+                pad_token,
+            ),
+            shared_state_buffer=None,
+            func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
+        )
+    elif tag == TOKEN_GENERATION_MODEL_TAG or tag == SPECULATION_MODEL_TAG:
+        return BucketModelConfig(
+            bucket_kernel=get_generation_model_bk,
+            bucket_kernel_constant_args=(
+                torch.tensor(buckets),
+                0,
+            ),
+            shared_state_buffer=None,
+            func_kwargs=[{"bucket_rank": i} for i in range(bucket_degree)],
+        )
+    else:
+        raise ValueError(
+            f"The supplied tag: {tag} is not supported for Bucketing. Only {CONTEXT_ENCODING_MODEL_TAG} and {TOKEN_GENERATION_MODEL_TAG} are supported"
+        )
+
+
+class NxDDecoderBuilder(NxDGraphBuilder):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        neuron_config: NxDNeuronConfig,
+        buckets: list[int],
+        bucket_n_active_tokens: bool,
+        model_cls,
+        tag="",
+        priority_model_idx: int = None,
+    ) -> None:
+        super().__init__(tag, priority_model_idx)
+        self.config = config
+        self.neuron_config = neuron_config
+        self.buckets = buckets
+        self.bucket_n_active_tokens = bucket_n_active_tokens
+
+        if not self.neuron_config.torch_dtype:
+            self.neuron_config.torch_dtype = torch.float32
+
+        if config.pad_token_id is None:
+            config.pad_token_id = 0
+
+        self.model_cls = model_cls
+
+    def input_generator(
+        self,
+    ):
+        inputs = []
+        for bucket in self.buckets:
+            n_active_tokens = bucket if self.bucket_n_active_tokens else self.neuron_config.n_active_tokens
+
+            input_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32)
+            attention_mask = torch.zeros((self.neuron_config.batch_size, bucket), dtype=torch.int32)
+            position_ids = torch.zeros((self.neuron_config.batch_size, n_active_tokens), dtype=torch.int32)
+            seq_ids = torch.zeros((self.neuron_config.batch_size), dtype=torch.int32)
+            # Get the count of sampling params currently supported.
+            sampling_params_len = prepare_sampling_params(1).shape[1]
+            sampling_params = torch.zeros((self.neuron_config.batch_size, sampling_params_len), dtype=torch.float32)
+
+            inputs.append((input_ids, attention_mask, position_ids, seq_ids, sampling_params))
+
+        return inputs
+
+    def get_model_instance(self):
+        return DecoderModelInstance(
+            model_cls=self.model_cls,
+            config=self.config,
+            neuron_config=self.neuron_config,
+            buckets=self.buckets,
+        )
+
+    def get_bucket_config(self):
+        return get_bucket_model_config_from_tag(self.tag, self.config, self.neuron_config, self.buckets)
+
+
+class DecoderModelInstance(BaseModelInstance):
+    def __init__(self, model_cls, config: PretrainedConfig, neuron_config: NxDNeuronConfig, buckets: list[int]):
+        self.model_cls = model_cls
+        self.module = None
+        self.input_output_aliases = None
+        self.config = config
+        self.neuron_config = neuron_config
+        self.buckets = buckets
+
+    def initialize_process_group(self, world_size):
+        self.model_cls.initialize_process_group(world_size)
+
+    def load_module(self):
+        float_model = self.model_cls(self.config, self.neuron_config)
+        float_model.eval()
+
+        if self.neuron_config.torch_dtype != torch.float32:
+            float_model._apply(
+                lambda t: t.to(self.neuron_config.torch_dtype)
+                if t.is_floating_point() and t.dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]
+                else t
+            )
+        self.module = float_model
+
+    def get(self, bucket_rank, **kwargs):
+        if bucket_rank is not None:
+            self.module.n_positions = self.buckets[bucket_rank]
+
+        # Currently we have to init an input_output_aliases map for
+        # each buckets, otherwise it will fail the aliasing setup when
+        # generating HLO
+        self.input_output_aliases = {}
+        num_output_from_trace = 1 if not self.neuron_config.output_logits else 2
+        # TODO: This else block is a short-term fix for Llava/ViT models to use DecoderModelInstance.
+        # Long-term, these models should use a different implementation of BaseModelInstance.
+        if self.module.kv_mgr is not None:
+            past_key_values = self.module.kv_mgr.past_key_values
+        else:
+            past_key_values = self.module.past_key_values
+        for i in range(len(past_key_values)):
+            self.input_output_aliases[past_key_values[i]] = num_output_from_trace + i
+        return self.module, self.input_output_aliases
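As a usage sketch, the bucket helper above returns None for a single bucket and otherwise wires the tag-specific bucket kernel into a BucketModelConfig. The snippet below is illustrative only: it assumes any PretrainedConfig with pad_token_id set is sufficient, since the helper only reads that field, and it passes None for the neuron_config argument, which the helper does not use in this commit.

# Hedged usage sketch for get_bucket_model_config_from_tag (illustrative only).
from transformers import PretrainedConfig

config = PretrainedConfig(pad_token_id=0)

# A single bucket: no bucketing machinery is needed.
assert get_bucket_model_config_from_tag(CONTEXT_ENCODING_MODEL_TAG, config, None, buckets=[1024]) is None

# Several buckets: a BucketModelConfig wired to the matching bucket kernel.
ctx_bucket_cfg = get_bucket_model_config_from_tag(
    CONTEXT_ENCODING_MODEL_TAG, config, None, buckets=[128, 512, 1024]
)
tok_bucket_cfg = get_bucket_model_config_from_tag(
    TOKEN_GENERATION_MODEL_TAG, config, None, buckets=[128, 512, 1024]
)

# Any other tag raises a ValueError.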

0 commit comments
