Skip to content

Commit 9129c6e

Browse files
committed
pre-commit
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
1 parent fab55d3 commit 9129c6e

File tree

2 files changed

+113
-125
lines changed

2 files changed

+113
-125
lines changed

tensorrt_llm/_torch/models/modeling_starcoder2.py

Lines changed: 60 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
from transformers import Starcoder2Config
66

77
from tensorrt_llm._torch.attention_backend import AttentionMetadata
8-
from tensorrt_llm._torch.attention_backend.interface import (
9-
PositionalEmbeddingParams, RopeParams)
8+
from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams
109
from tensorrt_llm._torch.model_config import ModelConfig
11-
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
12-
DecoderModelForCausalLM,
13-
register_auto_model)
10+
from tensorrt_llm._torch.models.modeling_utils import (
11+
DecoderModel,
12+
DecoderModelForCausalLM,
13+
register_auto_model,
14+
)
1415
from tensorrt_llm._torch.modules.attention import Attention
1516
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
1617
from tensorrt_llm._torch.modules.embedding import Embedding
17-
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
18-
from tensorrt_llm._torch.modules.linear import Linear, TensorParallelMode
18+
from tensorrt_llm._torch.modules.linear import TensorParallelMode
1919
from tensorrt_llm._torch.modules.mlp import MLP
2020
from tensorrt_llm._torch.speculative import SpecMetadata
2121
from tensorrt_llm.functional import PositionEmbeddingType
@@ -24,15 +24,15 @@
2424
class Starcoder2LayerNorm(nn.LayerNorm):
2525
"""
2626
Custom LayerNorm that skips weight initialization to support meta tensor initialization.
27-
27+
2828
StarCoder2ForCausalLM inherits from DecoderModelForCausalLM which uses the PostInitCaller
2929
metaclass to enable meta tensor initialization (memory optimization). During model construction
3030
with meta tensors, PyTorch's nn.LayerNorm.reset_parameters() tries to initialize weights with
3131
ones_() which fails on meta tensors. This class skips that initialization step.
32-
32+
3333
The weights will be properly initialized later when loaded from the HuggingFace checkpoint.
3434
"""
35-
35+
3636
def reset_parameters(self) -> None:
3737
# Skip initialization operations that conflict with meta tensor initialization
3838
pass
@@ -63,10 +63,10 @@ def __init__(
6363
dtype=config.torch_dtype,
6464
config=model_config,
6565
)
66-
66+
6767
# Configure sliding window attention (falls back to 4096 tokens when the config has no sliding_window attribute)
68-
self.attention_window_size = getattr(config, 'sliding_window', 4096)
69-
68+
self.attention_window_size = getattr(config, "sliding_window", 4096)
69+
7070
def forward(
7171
self,
7272
position_ids: torch.IntTensor,
@@ -89,7 +89,7 @@ def forward(
8989
class Starcoder2DecoderLayer(DecoderLayer):
9090
"""
9191
StarCoder2 Decoder Layer.
92-
92+
9393
Architecture:
9494
- Layer normalization before attention (with bias)
9595
- Self-attention with GQA and sliding window
@@ -123,7 +123,7 @@ def __init__(
123123
else:
124124
raise ValueError(f"Unsupported mlp_type: {config.mlp_type}")
125125

126-
norm_eps = getattr(config, 'norm_epsilon', 1e-5)
126+
norm_eps = getattr(config, "norm_epsilon", 1e-5)
127127
self.input_layernorm = Starcoder2LayerNorm(
128128
config.hidden_size,
129129
eps=norm_eps,
@@ -149,8 +149,10 @@ def forward(
149149
residual = hidden_states
150150
hidden_states = self.input_layernorm(hidden_states)
151151
else:
152-
hidden_states, residual = self.input_layernorm(
153-
hidden_states + residual), hidden_states + residual
152+
hidden_states, residual = (
153+
self.input_layernorm(hidden_states + residual),
154+
hidden_states + residual,
155+
)
154156

155157
# Self Attention
156158
hidden_states = self.self_attn(
@@ -165,11 +167,10 @@ def forward(
165167
residual = hidden_states
166168
hidden_states = self.post_attention_layernorm(hidden_states)
167169
hidden_states = self.mlp(hidden_states)
168-
170+
169171
if spec_metadata is not None:
170-
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
171-
hidden_states, residual)
172-
172+
spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual)
173+
173174
return hidden_states, residual
174175

175176

@@ -190,16 +191,19 @@ def __init__(self, model_config: ModelConfig[Starcoder2Config]):
190191
tensor_parallel_mode=TensorParallelMode.COLUMN,
191192
gather_output=True,
192193
)
193-
194-
self.layers = nn.ModuleList([
195-
Starcoder2DecoderLayer(
196-
model_config,
197-
layer_idx,
198-
) for layer_idx in range(config.num_hidden_layers)
199-
])
200-
194+
195+
self.layers = nn.ModuleList(
196+
[
197+
Starcoder2DecoderLayer(
198+
model_config,
199+
layer_idx,
200+
)
201+
for layer_idx in range(config.num_hidden_layers)
202+
]
203+
)
204+
201205
# Use norm_epsilon (Starcoder2Config attribute name), falling back to 1e-5 when the attribute is absent
202-
norm_eps = getattr(config, 'norm_epsilon', 1e-5)
206+
norm_eps = getattr(config, "norm_epsilon", 1e-5)
203207
self.norm = Starcoder2LayerNorm(
204208
config.hidden_size,
205209
eps=norm_eps,
@@ -243,16 +247,16 @@ def forward(
243247

244248
@register_auto_model("Starcoder2ForCausalLM")
245249
class Starcoder2ForCausalLM(DecoderModelForCausalLM[Starcoder2Model, Starcoder2Config]):
246-
247250
def __init__(
248251
self,
249252
model_config: ModelConfig[Starcoder2Config],
250253
):
251-
# Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16).
254+
# Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16).
252255
# For the 15B FP32 checkpoint, we cast it to bfloat16 for consistency.
253-
if model_config.pretrained_config.torch_dtype is None or model_config.pretrained_config.torch_dtype == torch.float32:
256+
torch_dtype_to_check = model_config.pretrained_config.torch_dtype
257+
if torch_dtype_to_check is None or torch_dtype_to_check == torch.float32:
254258
model_config.pretrained_config.torch_dtype = torch.bfloat16
255-
259+
256260
super().__init__(
257261
Starcoder2Model(model_config),
258262
config=model_config,
@@ -263,27 +267,38 @@ def __init__(
263267
def load_weights(self, weights, weight_mapper=None, skip_modules=[]):
264268
"""
265269
Load weights with custom mapping for StarCoder2.
266-
270+
267271
StarCoder2 uses GPT-2 style MLP naming (c_fc, c_proj)
268272
while our MLP module expects (up_proj, down_proj).
269273
"""
270274
# Map HuggingFace StarCoder2 weight names to TensorRT-LLM names
271275
params_map = {
272-
r'(.*?)\.mlp\.c_fc\.(.*)': r'\1.mlp.up_proj.\2',
273-
r'(.*?)\.mlp\.c_proj\.(.*)': r'\1.mlp.down_proj.\2',
276+
r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
277+
r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
274278
}
275-
279+
276280
if weight_mapper is None:
277281
# Use _load_weights_impl for non-weight-mapper path
278282
from tensorrt_llm._torch.models.modeling_utils import _load_weights_impl
283+
279284
preload_weight_modules = getattr(self, "preload_weight_modules", None)
280-
_load_weights_impl(self, weights, skip_modules,
281-
params_map=params_map,
282-
preload_weight_modules=preload_weight_modules)
285+
_load_weights_impl(
286+
self,
287+
weights,
288+
skip_modules,
289+
params_map=params_map,
290+
preload_weight_modules=preload_weight_modules,
291+
)
283292
else:
284-
# Use _load_weights_impl_v2 for weight-mapper path
293+
# Use _load_weights_impl_v2 for weight-mapper path
285294
from tensorrt_llm._torch.models.modeling_utils import _load_weights_impl_v2
295+
286296
preload_weight_modules = getattr(self, "preload_weight_modules", None)
287-
_load_weights_impl_v2(self, weights, weight_mapper, skip_modules,
288-
params_map=params_map,
289-
preload_weight_modules=preload_weight_modules)
297+
_load_weights_impl_v2(
298+
self,
299+
weights,
300+
weight_mapper,
301+
skip_modules,
302+
params_map=params_map,
303+
preload_weight_modules=preload_weight_modules,
304+
)

0 commit comments

Comments
 (0)