Skip to content

Commit ff66d85

Browse files
committed
add eagle3 gpt-oss test
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
1 parent ac12cd9 commit ff66d85

File tree

6 files changed

+64
-3
lines changed

6 files changed

+64
-3
lines changed

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,10 +1071,10 @@ def update_spec_dec_param(
10711071
spec_decoding_packed_mask = None
10721072
spec_decoding_generation_lengths = None
10731073
# spec_dec mode should only be enabled for pre-Blackwell machines or SM 120, and when there's a spec-dec tree.
1074-
self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
1075-
) < 100
1074+
self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
1075+
get_sm_version() < 100 or get_sm_version() == 120)
10761076

1077-
if get_sm_version() >= 100:
1077+
if get_sm_version() >= 100 and get_sm_version() != 120:
10781078
if is_spec_dec_tree or is_spec_dec_dynamic_tree:
10791079
assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
10801080
assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,11 @@ GPT-OSS/BF16:
212212
- accuracy: 90.3
213213
- kv_cache_quant_algo: FP8
214214
accuracy: 90.3
215+
- quant_algo: W4A16_MXFP4
216+
accuracy: 90.3
217+
- quant_algo: W4A16_MXFP4
218+
spec_dec_algo: Eagle
219+
accuracy: 90.3
215220
GPT-OSS/120B-MXFP4:
216221
- accuracy: 90.3
217222
- quant_algo: W4A8_MXFP4_MXFP8

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3598,6 +3598,59 @@ def test_w4a16(self, kv_cache_dtype, tp_size, pp_size, ep_size,
35983598
task.evaluate(llm,
35993599
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
36003600

@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
@pytest.mark.parametrize(
    "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
        (4, 1, 4, False, True, True),
    ],
    ids=["tep4"])
@pytest.mark.parametrize(
    "moe_backend",
    ["triton", "cutlass",
     pytest.param("trtllm", marks=skip_pre_blackwell)])
def test_w4a16_eagle3(self, kv_cache_dtype, tp_size, pp_size, ep_size,
                      attention_dp, cuda_graph, overlap_scheduler,
                      moe_backend, monkeypatch, mocker):
    """GSM8K accuracy test for GPT-OSS W4A16 with Eagle3 speculative decoding.

    Runs the 120B model in TEP4 (tp=4, ep=4) with a one-model Eagle3
    draft and checks the score against the GPT-OSS/BF16 W4A16_MXFP4
    reference entry.
    """
    # Eagle3 chains of thought on GSM8K can be long; widen the budget.
    mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
    mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
                      {"scores_filter": "exact_match,flexible-extract"})
    # BUG FIX: the parametrize value is lowercase "triton"; the original
    # compared against "TRITON", so this skip was dead code and the test
    # would crash where Triton kernels are unavailable.
    if moe_backend == "triton" and not IS_TRITON_KERNELS_AVAILABLE:
        pytest.skip("Triton kernels are not available")
    monkeypatch.setenv("OVERRIDE_QUANT_ALGO", "W4A16_MXFP4")

    # BUG FIX: this config was previously constructed but never passed to
    # the LLM (a bare CudaGraphConfig() was used instead), silently
    # dropping enable_padding and the max_batch_size cap.
    cuda_graph_config = CudaGraphConfig(enable_padding=True,
                                        max_batch_size=8)

    pytorch_config = dict(
        max_batch_size=8,
        disable_overlap_scheduler=not overlap_scheduler,
        cuda_graph_config=cuda_graph_config if cuda_graph else None)

    # Keep half the GPU free: base model + Eagle3 draft share memory.
    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
                                    dtype=kv_cache_dtype)
    # One-model Eagle3 with a 3-token draft length.
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir=
        f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3/",
        eagle3_one_model=True)

    llm = LLM(self.MODEL_PATH,
              tensor_parallel_size=tp_size,
              pipeline_parallel_size=pp_size,
              moe_expert_parallel_size=ep_size,
              kv_cache_config=kv_cache_config,
              **pytorch_config,
              enable_attention_dp=attention_dp,
              moe_config=MoeConfig(backend=moe_backend),
              speculative_config=spec_config)

    with llm:
        # Scores are looked up under the BF16 entry with the
        # W4A16_MXFP4 / Eagle spec_dec_algo reference row.
        model_name = "GPT-OSS/BF16"
        task = GSM8K(model_name)
        task.evaluate(llm,
                      extra_evaluator_kwargs=self.extra_evaluator_kwargs)
3653+
36013654
@pytest.mark.skip_less_device(2)
36023655
@pytest.mark.parametrize(
36033656
"kv_cache_dtype",

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ l0_dgx_b200:
4949
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8]
5050
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
5151
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8]
52+
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16_eagle3[trtllm-tep4-auto]
5253
- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16]
5354
- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
5455
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ l0_dgx_h100:
174174
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
175175
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto]
176176
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
177+
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16_eagle3[triton-tep4-auto]
177178
- condition:
178179
ranges:
179180
system_gpu_count:

tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,6 @@ l0_rtx_pro_6000:
107107
# - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False] # failed
108108
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
109109
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]
110+
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16_eagle3[cutlass-tep4-auto]
110111
- test_e2e.py::test_ptp_quickstart_multimodal_2gpu[phi4-multimodal-instruct-fp8-multimodals/Phi-4-multimodal-instruct-FP8]
111112
- test_e2e.py::test_ptp_quickstart_multimodal_2gpu[phi4-multimodal-instruct-fp4-multimodals/Phi-4-multimodal-instruct-FP4]

0 commit comments

Comments
 (0)