Skip to content

Commit adbefc2

Browse files
committed
adding Starcoder2 PyTorch flow support
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
1 parent 5a01f38 commit adbefc2

File tree

2 files changed

+270
-0
lines changed

2 files changed

+270
-0
lines changed

tensorrt_llm/_torch/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .modeling_qwen_moe import Qwen2MoeForCausalLM
3131
from .modeling_seedoss import SeedOssForCausalLM
3232
from .modeling_siglip import SiglipVisionModel
33+
from .modeling_starcoder2 import Starcoder2ForCausalLM
3334
from .modeling_utils import get_model_architecture
3435
from .modeling_vila import VilaModel
3536

@@ -61,6 +62,7 @@
6162
"Qwen2ForRewardModel",
6263
"Qwen2MoeForCausalLM",
6364
"SiglipVisionModel",
65+
"Starcoder2ForCausalLM",
6466
"get_model_architecture",
6567
"VilaModel",
6668
"Qwen2VLModel",
Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
from typing import Optional
2+
3+
import torch
4+
from torch import nn
5+
from transformers import Starcoder2Config
6+
7+
from tensorrt_llm._torch.attention_backend import AttentionMetadata
8+
from tensorrt_llm._torch.attention_backend.interface import (
9+
PositionalEmbeddingParams, RopeParams)
10+
from tensorrt_llm._torch.model_config import ModelConfig
11+
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
12+
DecoderModelForCausalLM,
13+
register_auto_model)
14+
from tensorrt_llm._torch.modules.attention import Attention
15+
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
16+
from tensorrt_llm._torch.modules.embedding import Embedding
17+
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
18+
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
19+
from tensorrt_llm._torch.modules.mlp import MLP
20+
from tensorrt_llm._torch.speculative import SpecMetadata
21+
from tensorrt_llm.functional import PositionEmbeddingType
22+
23+
24+
class Starcoder2LayerNorm(nn.LayerNorm):
    """LayerNorm variant whose parameter reset is intentionally a no-op.

    StarCoder2ForCausalLM is built through DecoderModelForCausalLM, whose
    PostInitCaller metaclass constructs the model on meta tensors as a
    memory optimization.  PyTorch's ``nn.LayerNorm.reset_parameters()``
    calls ``ones_()`` on the weight, which fails on meta tensors, so this
    subclass disables the reset entirely.  Real parameter values arrive
    later when the HuggingFace checkpoint is loaded.
    """

    def reset_parameters(self) -> None:
        # Deliberately empty: parameters stay uninitialized until the
        # checkpoint loader fills them in.
        pass
39+
40+
41+
class Starcoder2Attention(Attention):
    """Self-attention for StarCoder2 (grouped-query attention + RoPE).

    Thin wrapper over the shared ``Attention`` module that wires in the
    StarCoder2-specific hyperparameters taken from the pretrained config.
    """

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
        layer_idx: Optional[int] = None,
    ):
        cfg = model_config.pretrained_config
        # GPT-NeoX style rotary embedding, parameterized from the HF config.
        rope_params = PositionalEmbeddingParams(
            type=PositionEmbeddingType.rope_gpt_neox,
            rope=RopeParams.from_config(cfg),
        )
        super().__init__(
            hidden_size=cfg.hidden_size,
            num_attention_heads=cfg.num_attention_heads,
            num_key_value_heads=cfg.num_key_value_heads,
            max_position_embeddings=cfg.max_position_embeddings,
            bias=cfg.use_bias,
            pos_embd_params=rope_params,
            layer_idx=layer_idx,
            dtype=cfg.torch_dtype,
            config=model_config,
        )
66+
67+
68+
class Starcoder2DecoderLayer(DecoderLayer):
    """StarCoder2 decoder layer.

    Architecture:
    - Layer normalization before attention (with bias)
    - Self-attention with GQA
    - Layer normalization before MLP (with bias)
    - MLP with GELU activation

    The layer follows the fused-residual convention used by the other
    decoder layers in this package: ``forward`` receives the previous
    layer's residual, folds it in before the input layernorm, and returns
    ``(hidden_states, residual)`` so the caller performs the final add.
    """

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
        layer_idx: int,
    ):
        """Build attention, MLP, and the two biased layernorms.

        Args:
            model_config: Wrapper holding the HF ``Starcoder2Config``.
            layer_idx: Index of this layer in the decoder stack.

        Raises:
            ValueError: If ``config.mlp_type`` is not ``"default"``.
        """
        super().__init__()
        config = model_config.pretrained_config
        self.layer_idx = layer_idx

        self.self_attn = Starcoder2Attention(
            model_config,
            layer_idx=layer_idx,
        )

        if config.mlp_type == "default":
            self.mlp = MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                bias=config.use_bias,
                activation=nn.GELU(),
                dtype=config.torch_dtype,
                config=model_config,
            )
        else:
            raise ValueError(f"Unsupported mlp_type: {config.mlp_type}")

        # Starcoder2Config names its epsilon 'norm_epsilon'; fall back to a
        # conventional default if a config omits it.
        norm_eps = getattr(config, 'norm_epsilon', 1e-5)
        self.input_layernorm = Starcoder2LayerNorm(
            config.hidden_size,
            eps=norm_eps,
            dtype=config.torch_dtype,
        )

        self.post_attention_layernorm = Starcoder2LayerNorm(
            config.hidden_size,
            eps=norm_eps,
            dtype=config.torch_dtype,
        )

    def forward(
        self,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        **kwargs,
    ):
        """Run one decoder layer.

        Returns:
            Tuple ``(hidden_states, residual)``: the MLP output and the
            pre-MLP residual; the caller adds them together.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Fold the residual in once; the original computed
            # `hidden_states + residual` twice (two tensor adds).
            residual = hidden_states + residual
            hidden_states = self.input_layernorm(residual)

        # Self Attention
        hidden_states = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            **kwargs,
        )

        # Fully Connected (MLP)
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        if spec_metadata is not None:
            spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                      hidden_states, residual)

        return hidden_states, residual
153+
154+
155+
class Starcoder2Model(DecoderModel):
    """StarCoder2 transformer backbone: embeddings, decoder stack, final norm."""

    def __init__(self, model_config: ModelConfig[Starcoder2Config]):
        super().__init__(model_config)
        config = self.model_config.pretrained_config

        self.embed_tokens = Embedding(
            config.vocab_size,
            config.hidden_size,
            dtype=config.torch_dtype,
            mapping=model_config.mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            gather_output=True,
        )

        self.layers = nn.ModuleList(
            Starcoder2DecoderLayer(model_config, idx)
            for idx in range(config.num_hidden_layers))

        # Starcoder2Config stores the epsilon under 'norm_epsilon'.
        self.norm = Starcoder2LayerNorm(
            config.hidden_size,
            eps=getattr(config, 'norm_epsilon', 1e-5),
            dtype=config.torch_dtype,
        )

    def forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.IntTensor] = None,
        position_ids: Optional[torch.IntTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        lora_params=None,
    ) -> torch.Tensor:
        """Embed (if needed), run every decoder layer, apply the final norm."""
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, "
                "and must specify either one"
            )

        hidden_states = (self.embed_tokens(input_ids)
                         if inputs_embeds is None else inputs_embeds)

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                position_ids=position_ids,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                spec_metadata=spec_metadata,
                lora_params=lora_params,
            )

        # Final residual add is folded into the last normalization.
        return self.norm(hidden_states + residual)
221+
222+
223+
@register_auto_model("Starcoder2ForCausalLM")
class Starcoder2ForCausalLM(DecoderModelForCausalLM[Starcoder2Model,
                                                    Starcoder2Config]):
    """StarCoder2 causal-LM head wrapping :class:`Starcoder2Model`."""

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
    ):
        # Ensure torch_dtype is set on pretrained_config (StarCoder2
        # checkpoints are published in bfloat16).
        if model_config.pretrained_config.torch_dtype is None:
            model_config.pretrained_config.torch_dtype = torch.bfloat16

        super().__init__(
            Starcoder2Model(model_config),
            config=model_config,
            hidden_size=model_config.pretrained_config.hidden_size,
            vocab_size=model_config.pretrained_config.vocab_size,
        )

    def load_weights(self, weights, weight_mapper=None, skip_modules=None):
        """Load weights with custom name mapping for StarCoder2.

        StarCoder2 uses GPT-2 style MLP naming (c_fc, c_proj) while our MLP
        module expects (up_proj, down_proj).

        Args:
            weights: Checkpoint state dict to load.
            weight_mapper: Optional mapper; when given, the v2 loading
                implementation is used.
            skip_modules: Module-name fragments to skip; defaults to none.
        """
        # `None` default instead of `[]` avoids the shared-mutable-default
        # pitfall (one list object reused across all calls).
        if skip_modules is None:
            skip_modules = []

        # Map HuggingFace StarCoder2 weight names to TensorRT-LLM names.
        params_map = {
            r'(.*?)\.mlp\.c_fc\.(.*)': r'\1.mlp.up_proj.\2',
            r'(.*?)\.mlp\.c_proj\.(.*)': r'\1.mlp.down_proj.\2',
        }
        # Hoisted out of the branches: both loading paths use it identically.
        preload_weight_modules = getattr(self, "preload_weight_modules", None)

        if weight_mapper is None:
            # Use _load_weights_impl for the non-weight-mapper path.
            from tensorrt_llm._torch.models.modeling_utils import \
                _load_weights_impl
            _load_weights_impl(self, weights, skip_modules,
                               params_map=params_map,
                               preload_weight_modules=preload_weight_modules)
        else:
            # Use _load_weights_impl_v2 for the weight-mapper path.
            from tensorrt_llm._torch.models.modeling_utils import \
                _load_weights_impl_v2
            _load_weights_impl_v2(self, weights, weight_mapper, skip_modules,
                                  params_map=params_map,
                                  preload_weight_modules=preload_weight_modules)
268+

0 commit comments

Comments
 (0)