@@ -3598,6 +3598,59 @@ def test_w4a16(self, kv_cache_dtype, tp_size, pp_size, ep_size,
35983598 task .evaluate (llm ,
35993599 extra_evaluator_kwargs = self .extra_evaluator_kwargs )
36003600
3601+ @pytest .mark .skip_less_device (4 )
3602+ @pytest .mark .parametrize ("kv_cache_dtype" , ["auto" ])
3603+ @pytest .mark .parametrize (
3604+ "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler" , [
3605+ (4 , 1 , 4 , False , True , True ),
3606+ ],
3607+ ids = ["tep4" ])
3608+ @pytest .mark .parametrize (
3609+ "moe_backend" ,
3610+ ["triton" , "cutlass" ,
3611+ pytest .param ("trtllm" , marks = skip_pre_blackwell )])
3612+ def test_w4a16_eagle3 (self , kv_cache_dtype , tp_size , pp_size , ep_size ,
3613+ attention_dp , cuda_graph , overlap_scheduler ,
3614+ moe_backend , monkeypatch , mocker ):
3615+ mocker .patch .object (GSM8K , "MAX_OUTPUT_LEN" , 8192 )
3616+ mocker .patch .dict (GSM8K .EVALUATE_KWARGS ,
3617+ {"scores_filter" : "exact_match,flexible-extract" })
3618+ if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE :
3619+ pytest .skip ("Triton kernels are not available" )
3620+ monkeypatch .setenv ("OVERRIDE_QUANT_ALGO" , "W4A16_MXFP4" )
3621+
3622+ cuda_graph_config = CudaGraphConfig (enable_padding = True ,
3623+ max_batch_size = 8 )
3624+
3625+ pytorch_config = dict (
3626+ max_batch_size = 8 ,
3627+ disable_overlap_scheduler = not overlap_scheduler ,
3628+ cuda_graph_config = CudaGraphConfig () if cuda_graph else None )
3629+
3630+ kv_cache_config = KvCacheConfig (free_gpu_memory_fraction = 0.5 ,
3631+ dtype = kv_cache_dtype )
3632+ spec_config = EagleDecodingConfig (
3633+ max_draft_len = 3 ,
3634+ speculative_model_dir =
3635+ f"{ llm_models_root ()} /gpt_oss/gpt-oss-120b-Eagle3/" ,
3636+ eagle3_one_model = True )
3637+
3638+ llm = LLM (self .MODEL_PATH ,
3639+ tensor_parallel_size = tp_size ,
3640+ pipeline_parallel_size = pp_size ,
3641+ moe_expert_parallel_size = ep_size ,
3642+ kv_cache_config = kv_cache_config ,
3643+ ** pytorch_config ,
3644+ enable_attention_dp = attention_dp ,
3645+ moe_config = MoeConfig (backend = moe_backend ),
3646+ speculative_config = spec_config )
3647+
3648+ with llm :
3649+ model_name = "GPT-OSS/BF16"
3650+ task = GSM8K (model_name )
3651+ task .evaluate (llm ,
3652+ extra_evaluator_kwargs = self .extra_evaluator_kwargs )
3653+
36013654 @pytest .mark .skip_less_device (2 )
36023655 @pytest .mark .parametrize (
36033656 "kv_cache_dtype" ,
0 commit comments