55from transformers import Starcoder2Config
66
77from tensorrt_llm ._torch .attention_backend import AttentionMetadata
8- from tensorrt_llm ._torch .attention_backend .interface import (
9- PositionalEmbeddingParams , RopeParams )
8+ from tensorrt_llm ._torch .attention_backend .interface import PositionalEmbeddingParams , RopeParams
109from tensorrt_llm ._torch .model_config import ModelConfig
11- from tensorrt_llm ._torch .models .modeling_utils import (DecoderModel ,
12- DecoderModelForCausalLM ,
13- register_auto_model )
10+ from tensorrt_llm ._torch .models .modeling_utils import (
11+ DecoderModel ,
12+ DecoderModelForCausalLM ,
13+ register_auto_model ,
14+ )
1415from tensorrt_llm ._torch .modules .attention import Attention
1516from tensorrt_llm ._torch .modules .decoder_layer import DecoderLayer
1617from tensorrt_llm ._torch .modules .embedding import Embedding
17- from tensorrt_llm ._torch .modules .gated_mlp import GatedMLP
18- from tensorrt_llm ._torch .modules .linear import Linear , TensorParallelMode
18+ from tensorrt_llm ._torch .modules .linear import TensorParallelMode
1919from tensorrt_llm ._torch .modules .mlp import MLP
2020from tensorrt_llm ._torch .speculative import SpecMetadata
2121from tensorrt_llm .functional import PositionEmbeddingType
class Starcoder2LayerNorm(nn.LayerNorm):
    """LayerNorm variant whose ``reset_parameters`` is a deliberate no-op.

    StarCoder2ForCausalLM inherits from DecoderModelForCausalLM, which uses the
    PostInitCaller metaclass to construct the model on meta tensors as a memory
    optimization. During that construction, the stock
    ``nn.LayerNorm.reset_parameters`` would call ``ones_()`` on the weight,
    which fails on meta tensors; this subclass therefore skips the
    initialization step entirely. The parameters receive their real values
    later, when the HuggingFace checkpoint is loaded.
    """

    def reset_parameters(self) -> None:
        # Intentionally empty: in-place init ops (ones_/zeros_) are invalid on
        # meta tensors, and checkpoint loading overwrites these weights anyway.
        pass
@@ -63,10 +63,10 @@ def __init__(
6363 dtype = config .torch_dtype ,
6464 config = model_config ,
6565 )
66-
66+
6767 # Configure sliding window attention (4096 tokens)
68- self .attention_window_size = getattr (config , ' sliding_window' , 4096 )
69-
68+ self .attention_window_size = getattr (config , " sliding_window" , 4096 )
69+
7070 def forward (
7171 self ,
7272 position_ids : torch .IntTensor ,
@@ -89,7 +89,7 @@ def forward(
8989class Starcoder2DecoderLayer (DecoderLayer ):
9090 """
9191 StarCoder2 Decoder Layer.
92-
92+
9393 Architecture:
9494 - Layer normalization before attention (with bias)
9595 - Self-attention with GQA and sliding window
@@ -123,7 +123,7 @@ def __init__(
123123 else :
124124 raise ValueError (f"Unsupported mlp_type: { config .mlp_type } " )
125125
126- norm_eps = getattr (config , ' norm_epsilon' , 1e-5 )
126+ norm_eps = getattr (config , " norm_epsilon" , 1e-5 )
127127 self .input_layernorm = Starcoder2LayerNorm (
128128 config .hidden_size ,
129129 eps = norm_eps ,
@@ -149,8 +149,10 @@ def forward(
149149 residual = hidden_states
150150 hidden_states = self .input_layernorm (hidden_states )
151151 else :
152- hidden_states , residual = self .input_layernorm (
153- hidden_states + residual ), hidden_states + residual
152+ hidden_states , residual = (
153+ self .input_layernorm (hidden_states + residual ),
154+ hidden_states + residual ,
155+ )
154156
155157 # Self Attention
156158 hidden_states = self .self_attn (
@@ -165,11 +167,10 @@ def forward(
165167 residual = hidden_states
166168 hidden_states = self .post_attention_layernorm (hidden_states )
167169 hidden_states = self .mlp (hidden_states )
168-
170+
169171 if spec_metadata is not None :
170- spec_metadata .maybe_capture_hidden_states (self .layer_idx ,
171- hidden_states , residual )
172-
172+ spec_metadata .maybe_capture_hidden_states (self .layer_idx , hidden_states , residual )
173+
173174 return hidden_states , residual
174175
175176
@@ -190,16 +191,19 @@ def __init__(self, model_config: ModelConfig[Starcoder2Config]):
190191 tensor_parallel_mode = TensorParallelMode .COLUMN ,
191192 gather_output = True ,
192193 )
193-
194- self .layers = nn .ModuleList ([
195- Starcoder2DecoderLayer (
196- model_config ,
197- layer_idx ,
198- ) for layer_idx in range (config .num_hidden_layers )
199- ])
200-
194+
195+ self .layers = nn .ModuleList (
196+ [
197+ Starcoder2DecoderLayer (
198+ model_config ,
199+ layer_idx ,
200+ )
201+ for layer_idx in range (config .num_hidden_layers )
202+ ]
203+ )
204+
201205 # Use norm_epsilon (Starcoder2Config attribute name)
202- norm_eps = getattr (config , ' norm_epsilon' , 1e-5 )
206+ norm_eps = getattr (config , " norm_epsilon" , 1e-5 )
203207 self .norm = Starcoder2LayerNorm (
204208 config .hidden_size ,
205209 eps = norm_eps ,
@@ -243,16 +247,16 @@ def forward(
243247
244248@register_auto_model ("Starcoder2ForCausalLM" )
245249class Starcoder2ForCausalLM (DecoderModelForCausalLM [Starcoder2Model , Starcoder2Config ]):
246-
247250 def __init__ (
248251 self ,
249252 model_config : ModelConfig [Starcoder2Config ],
250253 ):
251- # Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16).
254+ # Ensure torch_dtype is set on pretrained_config (StarCoder2 uses bfloat16).
252255 # For the 15B FP32 checkpoint, we cast it to bfloat16 for consistency.
253- if model_config .pretrained_config .torch_dtype is None or model_config .pretrained_config .torch_dtype == torch .float32 :
256+ torch_dtype_to_check = model_config .pretrained_config .torch_dtype
257+ if torch_dtype_to_check is None or torch_dtype_to_check == torch .float32 :
254258 model_config .pretrained_config .torch_dtype = torch .bfloat16
255-
259+
256260 super ().__init__ (
257261 Starcoder2Model (model_config ),
258262 config = model_config ,
@@ -263,27 +267,38 @@ def __init__(
def load_weights(self, weights, weight_mapper=None, skip_modules=None):
    """Load HuggingFace StarCoder2 weights with custom name mapping.

    StarCoder2 checkpoints use GPT-2 style MLP parameter names
    (``c_fc``/``c_proj``) while the local MLP module expects
    ``up_proj``/``down_proj``, so the names are remapped via regex before
    delegating to the shared loading helpers.

    Args:
        weights: Mapping of HuggingFace parameter names to tensors.
        weight_mapper: Optional weight mapper; when provided, the v2 loading
            path (``_load_weights_impl_v2``) is used instead of
            ``_load_weights_impl``.
        skip_modules: Module-name fragments to skip while loading. Defaults to
            an empty list. (Declared as ``None`` to avoid the shared
            mutable-default-argument pitfall of the original ``[]`` default.)
    """
    if skip_modules is None:
        skip_modules = []

    # Map HuggingFace StarCoder2 weight names to TensorRT-LLM names.
    params_map = {
        r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
        r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
    }

    # Optional hook some models define to force-load certain modules first.
    preload_weight_modules = getattr(self, "preload_weight_modules", None)

    if weight_mapper is None:
        # Non-weight-mapper path.
        from tensorrt_llm._torch.models.modeling_utils import _load_weights_impl

        _load_weights_impl(
            self,
            weights,
            skip_modules,
            params_map=params_map,
            preload_weight_modules=preload_weight_modules,
        )
    else:
        # Weight-mapper path.
        from tensorrt_llm._torch.models.modeling_utils import _load_weights_impl_v2

        _load_weights_impl_v2(
            self,
            weights,
            weight_mapper,
            skip_modules,
            params_map=params_map,
            preload_weight_modules=preload_weight_modules,
        )
0 commit comments