Skip to content

Commit 7d135ba

Browse files
committed
add integration accuracy tests and clean up per CodeRabbit suggestion
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
1 parent a60830b commit 7d135ba

File tree

5 files changed

+89
-20
lines changed

5 files changed

+89
-20
lines changed

tensorrt_llm/_torch/models/modeling_starcoder2.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
116
from typing import Optional
217

318
import torch
@@ -264,13 +279,16 @@ def __init__(
264279
vocab_size=model_config.pretrained_config.vocab_size,
265280
)
266281

267-
def load_weights(self, weights, weight_mapper=None, skip_modules=[]):
282+
def load_weights(self, weights, weight_mapper=None, skip_modules=None):
268283
"""
269284
Load weights with custom mapping for StarCoder2.
270285
271286
StarCoder2 uses GPT-2 style MLP naming (c_fc, c_proj)
272287
while our MLP module expects (up_proj, down_proj).
273288
"""
289+
if skip_modules is None:
290+
skip_modules = []
291+
274292
# Map HuggingFace StarCoder2 weight names to TensorRT-LLM names
275293
params_map = {
276294
r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4054,3 +4054,49 @@ def test_auto_dtype(self):
40544054
extra_evaluator_kwargs=dict(
40554055
apply_chat_template=True,
40564056
chat_template_kwargs=chat_template_kwargs))
4057+
4058+
4059+
class TestStarcoder2_3B(LlmapiAccuracyTestHarness):
4060+
MODEL_NAME = "bigcode/starcoder2-3b"
4061+
MODEL_PATH = f"{llm_models_root()}/starcoder2-3b/"
4062+
4063+
@skip_pre_hopper
4064+
def test_auto_dtype(self):
4065+
with LLM(self.MODEL_PATH,
4066+
attn_backend="TRTLLM",
4067+
cuda_graph_config=None,
4068+
max_batch_size=128,
4069+
max_seq_len=4096) as llm:
4070+
task = GSM8K(self.MODEL_NAME)
4071+
task.evaluate(llm)
4072+
4073+
4074+
class TestStarcoder2_7B(LlmapiAccuracyTestHarness):
4075+
MODEL_NAME = "bigcode/starcoder2-7b"
4076+
MODEL_PATH = f"{llm_models_root()}/starcoder2-7b/"
4077+
4078+
@skip_pre_hopper
4079+
def test_auto_dtype(self):
4080+
with LLM(self.MODEL_PATH,
4081+
attn_backend="TRTLLM",
4082+
cuda_graph_config=None,
4083+
max_batch_size=128,
4084+
max_seq_len=4096) as llm:
4085+
task = GSM8K(self.MODEL_NAME)
4086+
task.evaluate(llm)
4087+
4088+
4089+
class TestStarcoder2_15B(LlmapiAccuracyTestHarness):
4090+
MODEL_NAME = "bigcode/starcoder2-15b"
4091+
MODEL_PATH = f"{llm_models_root()}/starcoder2-15b/"
4092+
4093+
@skip_pre_hopper
4094+
@pytest.mark.skip_less_device_memory(80000)
4095+
def test_auto_dtype(self):
4096+
with LLM(self.MODEL_PATH,
4097+
attn_backend="TRTLLM",
4098+
cuda_graph_config=None,
4099+
max_batch_size=128,
4100+
max_seq_len=4096) as llm:
4101+
task = GSM8K(self.MODEL_NAME)
4102+
task.evaluate(llm)

tests/integration/test_lists/qa/llm_function_nim.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,10 @@ accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-c
382382
accuracy/test_llm_api_pytorch.py::TestQwQ_32B::test_auto_dtype_tp4
383383
accuracy/test_llm_api_pytorch.py::TestCodestral_22B_V01::test_auto_dtype
384384
accuracy/test_llm_api_pytorch.py::TestKimiK2::test_fp8_blockscale[latency]
385+
accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
386+
accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
387+
accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
388+
385389
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
386390
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_5_VL_7B::test_auto_dtype
387391
accuracy/test_llm_api_pytorch_multimodal.py::TestLlava_V1_6_Mistral_7B::test_auto_dtype

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,9 @@ l0_h100:
259259
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_eagle3[llguidance-eagle3_one_model=False]
260260
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[xgrammar]
261261
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_guided_decoding_with_ngram[llguidance]
262+
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_3B::test_auto_dtype
263+
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_7B::test_auto_dtype
264+
- accuracy/test_llm_api_pytorch.py::TestStarcoder2_15B::test_auto_dtype
262265
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True]
263266
- test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-mixture_text_image-True]
264267
- test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-0.6-image]

tests/unittest/_torch/modeling/test_modeling_starcoder2.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from tensorrt_llm._torch.metadata import KVCacheParams
1515
from tensorrt_llm._torch.model_config import ModelConfig
1616
from tensorrt_llm._torch.models.modeling_starcoder2 import Starcoder2ForCausalLM
17-
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner
1817
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
1918
from tensorrt_llm.bindings.executor import KvCacheConfig
2019
from tensorrt_llm.mapping import Mapping
@@ -114,7 +113,7 @@ def get_kv_cache_manager(
114113
elif dtype == torch.bfloat16:
115114
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
116115
else:
117-
raise ValueError("Invalid dtype")
116+
raise ValueError(f"Invalid dtype: {dtype}")
118117

119118
mapping = Mapping(world_size=1, tp_size=1, rank=0)
120119
kv_cache_config = KvCacheConfig(
@@ -160,7 +159,7 @@ def test_starcoder2_sanity(self):
160159

161160
input_ids = torch.tensor(
162161
[100, 200, 300, 400, 500, 600, 700, 800],
163-
dtype=torch.int,
162+
dtype=torch.long,
164163
device=device,
165164
)
166165

@@ -188,7 +187,7 @@ def test_starcoder2_sanity(self):
188187

189188
metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
190189
attn_metadata = metadata_cls(
191-
seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
190+
seq_lens=torch.tensor(sequence_lengths, dtype=torch.long),
192191
num_contexts=len(context_sequence_lengths),
193192
kv_cache_params=KVCacheParams(
194193
use_cache=True,
@@ -302,7 +301,7 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
302301
# Context phase (no CUDA graphs for prefill)
303302
input_ids = torch.tensor(
304303
[100, 200, 300, 400, 500, 600, 700, 800],
305-
dtype=torch.int32,
304+
dtype=torch.long,
306305
device=device,
307306
)
308307
num_cached_tokens_per_seq = [0]
@@ -312,7 +311,7 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
312311
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
313312

314313
attn_metadata = metadata_cls(
315-
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
314+
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.long),
316315
num_contexts=1,
317316
kv_cache_params=KVCacheParams(
318317
use_cache=True,
@@ -325,7 +324,7 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
325324
prompt_lens=prompt_lens,
326325
)
327326

328-
position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.int32)]
327+
position_ids = [torch.arange(0, input_ids.size(-1), dtype=torch.long)]
329328
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
330329

331330
with torch.inference_mode():
@@ -343,11 +342,11 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
343342
torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4)
344343

345344
# Generation phase (optionally with CUDA graphs)
346-
gen_input_ids = torch.tensor([900], dtype=torch.int32, device=device)
345+
gen_input_ids = torch.tensor([900], dtype=torch.long, device=device)
347346
num_cached_tokens_per_seq = [input_ids.size(-1)]
348347

349348
attn_metadata = metadata_cls(
350-
seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
349+
seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.long),
351350
num_contexts=0,
352351
kv_cache_params=KVCacheParams(
353352
use_cache=True,
@@ -362,18 +361,17 @@ def test_starcoder2_allclose_to_hf(self, scenario: Scenario) -> None:
362361

363362
gen_position_ids = [
364363
torch.arange(
365-
input_ids.size(-1), input_ids.size(-1) + gen_input_ids.size(-1), dtype=torch.int32
364+
input_ids.size(-1), input_ids.size(-1) + gen_input_ids.size(-1), dtype=torch.long
366365
)
367366
]
368367
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
369368

370369
# Setup CUDA graph runner if requested
371370
graph_runner = None
372371
if use_cuda_graph:
373-
from _torch.helpers import create_mock_engine
372+
from _torch.helpers import create_mock_cuda_graph_runner
374373

375-
mock_engine = create_mock_engine(1)
376-
graph_runner = CUDAGraphRunner(mock_engine)
374+
graph_runner = create_mock_cuda_graph_runner(1)
377375
attn_metadata = attn_metadata.create_cuda_graph_metadata(1)
378376

379377
# Run generation phase
@@ -476,7 +474,7 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
476474
# Encode test prompt
477475
input_ids = torch.tensor(
478476
tokenizer.encode(test_prompt),
479-
dtype=torch.int32,
477+
dtype=torch.long,
480478
device=device,
481479
)
482480

@@ -508,7 +506,7 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
508506
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
509507

510508
attn_metadata = metadata_cls(
511-
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
509+
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.long),
512510
num_contexts=1,
513511
kv_cache_params=KVCacheParams(
514512
use_cache=True,
@@ -522,7 +520,7 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
522520
)
523521

524522
position_ids = torch.arange(
525-
0, input_ids.size(-1), dtype=torch.int32, device=device
523+
0, input_ids.size(-1), dtype=torch.long, device=device
526524
).unsqueeze(0)
527525

528526
with torch.inference_mode():
@@ -540,10 +538,10 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
540538

541539
# Generation phase - generate remaining tokens
542540
for step in range(1, max_new_tokens):
543-
gen_input_ids = torch.tensor([next_token_id], dtype=torch.int32, device=device)
541+
gen_input_ids = torch.tensor([next_token_id], dtype=torch.long, device=device)
544542

545543
attn_metadata = metadata_cls(
546-
seq_lens=torch.tensor([1], dtype=torch.int),
544+
seq_lens=torch.tensor([1], dtype=torch.long),
547545
num_contexts=0,
548546
kv_cache_params=KVCacheParams(
549547
use_cache=True,
@@ -557,7 +555,7 @@ def test_starcoder2_generated_tokens_match_hf(self, scenario: Scenario) -> None:
557555
)
558556

559557
gen_position_ids = torch.arange(
560-
num_cached_tokens, num_cached_tokens + 1, dtype=torch.int32, device=device
558+
num_cached_tokens, num_cached_tokens + 1, dtype=torch.long, device=device
561559
).unsqueeze(0)
562560

563561
with torch.inference_mode():

0 commit comments

Comments (0)