Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .modeling_qwen_moe import Qwen2MoeForCausalLM
from .modeling_seedoss import SeedOssForCausalLM
from .modeling_siglip import SiglipVisionModel
from .modeling_starcoder2 import Starcoder2ForCausalLM
from .modeling_utils import get_model_architecture
from .modeling_vila import VilaModel

Expand Down Expand Up @@ -62,6 +63,7 @@
"Qwen2ForRewardModel",
"Qwen2MoeForCausalLM",
"SiglipVisionModel",
"Starcoder2ForCausalLM",
"get_model_architecture",
"VilaModel",
"Qwen2VLModel",
Expand Down
287 changes: 287 additions & 0 deletions tensorrt_llm/_torch/models/modeling_starcoder2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch import nn
from transformers import Starcoder2Config

from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (
DecoderModel,
DecoderModelForCausalLM,
_load_weights_impl,
register_auto_model,
)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm._torch.modules.mlp import MLP
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType


class Starcoder2Attention(Attention):
    """StarCoder2 attention: grouped-query attention with a sliding window.

    A thin wrapper over the generic ``Attention`` module that wires in the
    StarCoder2 config fields and forwards the sliding-window size on every call.
    """

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
        layer_idx: Optional[int] = None,
    ):
        pretrained = model_config.pretrained_config
        rope = PositionalEmbeddingParams(
            type=PositionEmbeddingType.rope_gpt_neox,
            rope=RopeParams.from_config(pretrained),
        )
        super().__init__(
            hidden_size=pretrained.hidden_size,
            num_attention_heads=pretrained.num_attention_heads,
            num_key_value_heads=pretrained.num_key_value_heads,
            max_position_embeddings=pretrained.max_position_embeddings,
            bias=pretrained.use_bias,
            pos_embd_params=rope,
            layer_idx=layer_idx,
            dtype=pretrained.torch_dtype,
            config=model_config,
        )

        # Sliding-window attention; falls back to 4096 tokens when the config
        # does not define `sliding_window`.
        self.attention_window_size = getattr(pretrained, "sliding_window", 4096)

    def forward(
        self,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate to the base attention, injecting the sliding-window size."""
        return super().forward(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            attention_window_size=self.attention_window_size,
            **kwargs,
        )


class Starcoder2DecoderLayer(DecoderLayer):
    """One StarCoder2 transformer layer.

    Pre-norm structure:
        LayerNorm (with bias) -> self-attention (GQA + sliding window)
        LayerNorm (with bias) -> MLP (GELU)
    Residual bookkeeping is fused into the LayerNorm calls.
    """

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
        layer_idx: int,
    ):
        super().__init__()
        cfg = model_config.pretrained_config
        self.layer_idx = layer_idx

        self.self_attn = Starcoder2Attention(model_config, layer_idx=layer_idx)

        # Only the plain linear MLP variant of StarCoder2 is implemented.
        if cfg.mlp_type != "default":
            raise ValueError(
                f"Unsupported mlp_type: {cfg.mlp_type}. Only default (linear) MLP is supported."
            )
        self.mlp = MLP(
            hidden_size=cfg.hidden_size,
            intermediate_size=cfg.intermediate_size,
            bias=cfg.use_bias,
            activation=nn.GELU(),
            dtype=cfg.torch_dtype,
            config=model_config,
        )

        # Starcoder2Config names its epsilon `norm_epsilon`; both norms carry a bias.
        eps = getattr(cfg, "norm_epsilon", 1e-5)
        self.input_layernorm = LayerNorm(
            hidden_size=cfg.hidden_size,
            eps=eps,
            dtype=cfg.torch_dtype,
            has_bias=True,
        )
        self.post_attention_layernorm = LayerNorm(
            hidden_size=cfg.hidden_size,
            eps=eps,
            dtype=cfg.torch_dtype,
            has_bias=True,
        )

    def forward(
        self,
        position_ids: torch.IntTensor,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        **kwargs,
    ):
        # First layer in the stack: the raw input doubles as the residual.
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        # Self-attention.
        hidden_states = self.self_attn(
            position_ids=position_ids,
            hidden_states=hidden_states,
            attn_metadata=attn_metadata,
            **kwargs,
        )

        # Fused add-residual + norm, then feed-forward.
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        # Let speculative decoding snapshot per-layer activations if requested.
        if spec_metadata is not None:
            spec_metadata.maybe_capture_hidden_states(self.layer_idx, hidden_states, residual)

        return hidden_states, residual


class Starcoder2Model(DecoderModel):
    """StarCoder2 transformer stack: embeddings, decoder layers, final norm."""

    def __init__(self, model_config: ModelConfig[Starcoder2Config]):
        super().__init__(model_config)
        cfg = self.model_config.pretrained_config

        self.embed_tokens = Embedding(
            cfg.vocab_size,
            cfg.hidden_size,
            dtype=cfg.torch_dtype,
            mapping=model_config.mapping,
            tensor_parallel_mode=TensorParallelMode.COLUMN,
            gather_output=True,
        )

        self.layers = nn.ModuleList(
            Starcoder2DecoderLayer(model_config, idx)
            for idx in range(cfg.num_hidden_layers)
        )

        # Starcoder2Config names its epsilon `norm_epsilon`; the final norm
        # carries a bias like the per-layer norms.
        self.norm = LayerNorm(
            hidden_size=cfg.hidden_size,
            eps=getattr(cfg, "norm_epsilon", 1e-5),
            dtype=cfg.torch_dtype,
            has_bias=True,
        )

    def forward(
        self,
        attn_metadata: AttentionMetadata,
        input_ids: Optional[torch.IntTensor] = None,
        position_ids: Optional[torch.IntTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        spec_metadata: Optional[SpecMetadata] = None,
        lora_params=None,
    ) -> torch.Tensor:
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds.")

        hidden_states = self.embed_tokens(input_ids) if inputs_embeds is None else inputs_embeds

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                position_ids=position_ids,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                residual=residual,
                spec_metadata=spec_metadata,
                lora_params=lora_params,
            )

        # The final norm folds in the last residual connection.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


@register_auto_model("Starcoder2ForCausalLM")
class Starcoder2ForCausalLM(DecoderModelForCausalLM[Starcoder2Model, Starcoder2Config]):
    """Causal-LM head on top of :class:`Starcoder2Model`."""

    def __init__(
        self,
        model_config: ModelConfig[Starcoder2Config],
    ):
        pretrained = model_config.pretrained_config
        # StarCoder2 ships in bfloat16; a missing dtype or the 15B FP32
        # checkpoint is normalized to bfloat16 for consistency.
        if pretrained.torch_dtype is None or pretrained.torch_dtype == torch.float32:
            pretrained.torch_dtype = torch.bfloat16

        super().__init__(
            Starcoder2Model(model_config),
            config=model_config,
            hidden_size=pretrained.hidden_size,
            vocab_size=pretrained.vocab_size,
        )

    def load_weights(self, weights, weight_mapper=None, skip_modules=None):
        """Load HuggingFace checkpoints, renaming GPT-2-style MLP weights.

        StarCoder2 checkpoints use ``c_fc``/``c_proj`` for the MLP, while this
        implementation's MLP expects ``up_proj``/``down_proj``.

        NOTE(review): ``weight_mapper`` is accepted for interface
        compatibility but is not consulted here — confirm callers never
        rely on it.
        """
        # HF StarCoder2 name -> TensorRT-LLM name (regex substitutions).
        params_map = {
            r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
            r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
        }
        _load_weights_impl(
            self,
            weights,
            skip_modules if skip_modules is not None else [],
            params_map=params_map,
            preload_weight_modules=getattr(self, "preload_weight_modules", None),
        )
6 changes: 6 additions & 0 deletions tests/integration/defs/accuracy/references/gsm8k.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,9 @@ zai-org/GLM-4.6:
- quant_algo: NVFP4
spec_dec_algo: MTP
accuracy: 88.0
bigcode/starcoder2-3b:
- accuracy: 20.2
bigcode/starcoder2-7b:
- accuracy: 26.5
bigcode/starcoder2-15b:
- accuracy: 54.5
46 changes: 46 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4258,3 +4258,49 @@ def test_nvfp4_4gpus(self):
if temp_dir and os.path.exists(temp_dir):
import shutil
shutil.rmtree(temp_dir, ignore_errors=True)


class TestStarcoder2_3B(LlmapiAccuracyTestHarness):
    """GSM8K accuracy check for StarCoder2-3B with the PyTorch backend."""
    MODEL_NAME = "bigcode/starcoder2-3b"
    MODEL_PATH = f"{llm_models_root()}/starcoder2-3b/"

    @skip_pre_hopper
    def test_auto_dtype(self):
        # TRTLLM attention backend, CUDA graphs disabled.
        llm_kwargs = dict(attn_backend="TRTLLM",
                          cuda_graph_config=None,
                          max_batch_size=128,
                          max_seq_len=4096)
        with LLM(self.MODEL_PATH, **llm_kwargs) as llm:
            GSM8K(self.MODEL_NAME).evaluate(llm)


class TestStarcoder2_7B(LlmapiAccuracyTestHarness):
    """GSM8K accuracy check for StarCoder2-7B with the PyTorch backend."""
    MODEL_NAME = "bigcode/starcoder2-7b"
    MODEL_PATH = f"{llm_models_root()}/starcoder2-7b/"

    @skip_pre_hopper
    def test_auto_dtype(self):
        # TRTLLM attention backend, CUDA graphs disabled.
        llm_kwargs = dict(attn_backend="TRTLLM",
                          cuda_graph_config=None,
                          max_batch_size=128,
                          max_seq_len=4096)
        with LLM(self.MODEL_PATH, **llm_kwargs) as llm:
            GSM8K(self.MODEL_NAME).evaluate(llm)


class TestStarcoder2_15B(LlmapiAccuracyTestHarness):
    """GSM8K accuracy check for StarCoder2-15B (needs >= 80 GB device memory)."""
    MODEL_NAME = "bigcode/starcoder2-15b"
    MODEL_PATH = f"{llm_models_root()}/starcoder2-15b/"

    @skip_pre_hopper
    @pytest.mark.skip_less_device_memory(80000)
    def test_auto_dtype(self):
        # TRTLLM attention backend, CUDA graphs disabled.
        llm_kwargs = dict(attn_backend="TRTLLM",
                          cuda_graph_config=None,
                          max_batch_size=128,
                          max_seq_len=4096)
        with LLM(self.MODEL_PATH, **llm_kwargs) as llm:
            GSM8K(self.MODEL_NAME).evaluate(llm)
4 changes: 4 additions & 0 deletions tests/integration/test_lists/qa/llm_function_nim.txt
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-c
accuracy/test_llm_api_pytorch.py::TestQwQ_32B::test_auto_dtype_tp4
accuracy/test_llm_api_pytorch.py::TestCodestral_22B_V01::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype

accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
accuracy/test_llm_api_pytorch_multimodal.py::TestLlava_V1_6_Mistral_7B::test_auto_dtype
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_a30.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ l0_a30:
- unittest/_torch/modeling -k "modeling_qwen"
- unittest/_torch/modeling -k "modeling_qwen_moe"
- unittest/_torch/modeling -k "modeling_out_of_tree"
- unittest/_torch/modeling -k "modeling_starcoder2"
- unittest/_torch/auto_deploy/unit/singlegpu
- unittest/_torch/sampler/test_beam_search.py
- unittest/_torch/sampler/test_return_logits.py
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/test_lists/test-db/l0_h100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ l0_h100:
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=False]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
- test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]
Expand Down
Loading