Skip to content

Commit c738411

Browse files
committed
add eagle3 gpt-oss test
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
1 parent 13f87aa commit c738411

File tree

6 files changed

+65
-3
lines changed

6 files changed

+65
-3
lines changed

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,10 +1071,10 @@ def update_spec_dec_param(
10711071
spec_decoding_packed_mask = None
10721072
spec_decoding_generation_lengths = None
10731073
# spec_dec mode should only be enabled when there's a spec-dec tree and the machine is either pre-Blackwell (SM < 100) or SM 120.
1074-
self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
1075-
) < 100
1074+
self.is_spec_decoding_enabled = is_spec_decoding_enabled and (
1075+
get_sm_version() < 100 or get_sm_version() == 120)
10761076

1077-
if get_sm_version() >= 100:
1077+
if get_sm_version() >= 100 and get_sm_version() != 120:
10781078
if is_spec_dec_tree or is_spec_dec_dynamic_tree:
10791079
assert not is_spec_dec_tree, "Spec-dec tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec tree."
10801080
assert not is_spec_dec_dynamic_tree, "Spec-dec dynamic tree is not supported on this machine. Please use a pre-Blackwell machine for a spec-dec dynamic tree."

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,11 @@ GPT-OSS/BF16:
217217
- accuracy: 90.3
218218
- kv_cache_quant_algo: FP8
219219
accuracy: 90.3
220+
- quant_algo: W4A16_MXFP4
221+
accuracy: 90.3
222+
- quant_algo: W4A16_MXFP4
223+
spec_dec_algo: Eagle
224+
accuracy: 90.3
220225
GPT-OSS/120B-MXFP4:
221226
- accuracy: 90.3
222227
- quant_algo: W4A8_MXFP4_MXFP8

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3650,6 +3650,60 @@ def test_w4a16(self, kv_cache_dtype, tp_size, pp_size, ep_size,
36503650
task.evaluate(llm,
36513651
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
36523652

3653+
@pytest.mark.skip_less_device(4)
3654+
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
3655+
@pytest.mark.parametrize(
3656+
"tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
3657+
(4, 1, 4, False, True, True),
3658+
],
3659+
ids=["tep4"])
3660+
@pytest.mark.parametrize(
3661+
"moe_backend",
3662+
["CUTLASS",
3663+
pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
3664+
ids=["cutlass", "trtllm", "triton"])
3665+
def test_w4a16_eagle3(self, kv_cache_dtype, tp_size, pp_size, ep_size,
3666+
attention_dp, cuda_graph, overlap_scheduler,
3667+
moe_backend, monkeypatch, mocker):
3668+
mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
3669+
mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
3670+
{"scores_filter": "exact_match,flexible-extract"})
3671+
if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
3672+
pytest.skip("Triton kernels are not available")
3673+
monkeypatch.setenv("OVERRIDE_QUANT_ALGO", "W4A16_MXFP4")
3674+
3675+
cuda_graph_config = CudaGraphConfig(enable_padding=True,
3676+
max_batch_size=8)
3677+
3678+
pytorch_config = dict(
3679+
max_batch_size=8,
3680+
disable_overlap_scheduler=not overlap_scheduler,
3681+
cuda_graph_config=CudaGraphConfig() if cuda_graph else None)
3682+
3683+
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
3684+
dtype=kv_cache_dtype)
3685+
spec_config = EagleDecodingConfig(
3686+
max_draft_len=3,
3687+
speculative_model_dir=
3688+
f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3/",
3689+
eagle3_one_model=True)
3690+
3691+
llm = LLM(self.MODEL_PATH,
3692+
tensor_parallel_size=tp_size,
3693+
pipeline_parallel_size=pp_size,
3694+
moe_expert_parallel_size=ep_size,
3695+
kv_cache_config=kv_cache_config,
3696+
**pytorch_config,
3697+
enable_attention_dp=attention_dp,
3698+
moe_config=MoeConfig(backend=moe_backend),
3699+
speculative_config=spec_config)
3700+
3701+
with llm:
3702+
model_name = "GPT-OSS/BF16"
3703+
task = GSM8K(model_name)
3704+
task.evaluate(llm,
3705+
extra_evaluator_kwargs=self.extra_evaluator_kwargs)
3706+
36533707
@pytest.mark.skip_less_device(2)
36543708
@pytest.mark.parametrize(
36553709
"kv_cache_dtype",

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ l0_dgx_b200:
4949
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8]
5050
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
5151
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-fp8]
52+
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16_eagle3[trtllm-tep4-auto]
5253
- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-bf16]
5354
- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
5455
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ l0_dgx_h100:
182182
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
183183
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-triton-auto]
184184
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16[dp4-auto]
185+
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16_eagle3[triton-tep4-auto]
185186
- condition:
186187
ranges:
187188
system_gpu_count:

tests/integration/test_lists/test-db/l0_rtx_pro_6000.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,5 +107,6 @@ l0_rtx_pro_6000:
107107
# - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False] # failed
108108
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
109109
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]
110+
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4a16_eagle3[cutlass-tep4-auto]
110111
- test_e2e.py::test_ptp_quickstart_multimodal_2gpu[phi4-multimodal-instruct-fp8-multimodals/Phi-4-multimodal-instruct-FP8]
111112
- test_e2e.py::test_ptp_quickstart_multimodal_2gpu[phi4-multimodal-instruct-fp4-multimodals/Phi-4-multimodal-instruct-FP4]

0 commit comments

Comments
 (0)