@@ -3595,6 +3595,59 @@ def test_w4a16(self, kv_cache_dtype, tp_size, pp_size, ep_size,
35953595 task .evaluate (llm ,
35963596 extra_evaluator_kwargs = self .extra_evaluator_kwargs )
35973597
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
@pytest.mark.parametrize(
    "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
        (4, 1, 4, False, True, True),
    ],
    ids=["tep4"])
@pytest.mark.parametrize(
    "moe_backend",
    ["triton", "cutlass",
     pytest.param("trtllm", marks=skip_pre_blackwell)])
def test_w4a16_eagle3(self, kv_cache_dtype, tp_size, pp_size, ep_size,
                      attention_dp, cuda_graph, overlap_scheduler,
                      moe_backend, monkeypatch, mocker):
    """Evaluate GSM8K with Eagle3 one-model speculative decoding.

    The run sets ``OVERRIDE_QUANT_ALGO=W4A16_MXFP4`` and is parametrized
    over the MoE backend; the parallelism layout comes from the ``tep4``
    parameter set (tp=4, pp=1, ep=4).

    Args:
        kv_cache_dtype: dtype forwarded to ``KvCacheConfig``.
        tp_size: tensor-parallel size passed to ``LLM``.
        pp_size: pipeline-parallel size passed to ``LLM``.
        ep_size: MoE expert-parallel size passed to ``LLM``.
        attention_dp: value for ``enable_attention_dp``.
        cuda_graph: when true, a ``CudaGraphConfig`` is passed to ``LLM``;
            otherwise ``cuda_graph_config`` is ``None``.
        overlap_scheduler: when false, the overlap scheduler is disabled.
        moe_backend: backend name forwarded to ``MoeConfig``.
        monkeypatch: pytest fixture used to set the quantization override.
        mocker: pytest-mock fixture used to patch GSM8K settings.
    """
    # The parametrized backend names are lowercase, so compare
    # case-insensitively; an exact match against "TRITON" never fired and
    # the Triton case ran even when the kernels were unavailable.
    if moe_backend.upper() == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
        pytest.skip("Triton kernels are not available")

    mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
    mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
                      {"scores_filter": "exact_match,flexible-extract"})
    monkeypatch.setenv("OVERRIDE_QUANT_ALGO", "W4A16_MXFP4")

    # Pass the tuned CUDA graph config (padding enabled, batch size capped
    # to match max_batch_size below). It was previously constructed but
    # never used; a default CudaGraphConfig() was passed instead.
    cuda_graph_config = (CudaGraphConfig(enable_padding=True,
                                         max_batch_size=8)
                         if cuda_graph else None)

    pytorch_config = dict(
        max_batch_size=8,
        disable_overlap_scheduler=not overlap_scheduler,
        cuda_graph_config=cuda_graph_config)

    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
                                    dtype=kv_cache_dtype)
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir=
        f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3/",
        eagle3_one_model=True)

    # NOTE(review): moe_backend is forwarded exactly as parametrized
    # (lowercase) -- confirm MoeConfig accepts lowercase backend names.
    llm = LLM(self.MODEL_PATH,
              tensor_parallel_size=tp_size,
              pipeline_parallel_size=pp_size,
              moe_expert_parallel_size=ep_size,
              kv_cache_config=kv_cache_config,
              **pytorch_config,
              enable_attention_dp=attention_dp,
              moe_config=MoeConfig(backend=moe_backend),
              speculative_config=spec_config)

    with llm:
        model_name = "GPT-OSS/BF16"
        task = GSM8K(model_name)
        task.evaluate(llm,
                      extra_evaluator_kwargs=self.extra_evaluator_kwargs)

3650+
35983651 @pytest .mark .skip_less_device (2 )
35993652 @pytest .mark .parametrize (
36003653 "kv_cache_dtype" ,
0 commit comments