@@ -3650,6 +3650,60 @@ def test_w4a16(self, kv_cache_dtype, tp_size, pp_size, ep_size,
36503650 task .evaluate (llm ,
36513651 extra_evaluator_kwargs = self .extra_evaluator_kwargs )
36523652
@pytest.mark.skip_less_device(4)
@pytest.mark.parametrize("kv_cache_dtype", ["auto"])
@pytest.mark.parametrize(
    "tp_size,pp_size,ep_size,attention_dp,cuda_graph,overlap_scheduler", [
        (4, 1, 4, False, True, True),
    ],
    ids=["tep4"])
@pytest.mark.parametrize(
    "moe_backend",
    ["CUTLASS",
     pytest.param("TRTLLM", marks=skip_pre_blackwell), "TRITON"],
    ids=["cutlass", "trtllm", "triton"])
def test_w4a16_eagle3(self, kv_cache_dtype, tp_size, pp_size, ep_size,
                      attention_dp, cuda_graph, overlap_scheduler,
                      moe_backend, monkeypatch, mocker):
    """Evaluate GSM8K with the W4A16_MXFP4 override and Eagle3 decoding.

    The test sets ``OVERRIDE_QUANT_ALGO=W4A16_MXFP4``, builds an ``LLM``
    from ``self.MODEL_PATH`` with an ``EagleDecodingConfig``
    (``eagle3_one_model=True``, draft length 3), and runs the ``GSM8K``
    task against it.

    Args:
        kv_cache_dtype: Forwarded to ``KvCacheConfig(dtype=...)``.
        tp_size: Forwarded as ``tensor_parallel_size``.
        pp_size: Forwarded as ``pipeline_parallel_size``.
        ep_size: Forwarded as ``moe_expert_parallel_size``.
        attention_dp: Forwarded as ``enable_attention_dp``.
        cuda_graph: When true, a ``CudaGraphConfig`` is passed to the LLM;
            otherwise ``cuda_graph_config`` is ``None``.
        overlap_scheduler: When false, ``disable_overlap_scheduler`` is set.
        moe_backend: Backend name passed to ``MoeConfig(backend=...)``.
        monkeypatch: pytest fixture used to set the environment override.
        mocker: pytest-mock fixture used to patch ``GSM8K`` settings.
    """
    # Skip before patching anything: there is nothing to run without the
    # Triton kernels.
    if moe_backend == "TRITON" and not IS_TRITON_KERNELS_AVAILABLE:
        pytest.skip("Triton kernels are not available")

    mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN", 8192)
    mocker.patch.dict(GSM8K.EVALUATE_KWARGS,
                      {"scores_filter": "exact_match,flexible-extract"})
    monkeypatch.setenv("OVERRIDE_QUANT_ALGO", "W4A16_MXFP4")

    # Single source of truth so the LLM batch size and the CUDA graph batch
    # size cannot drift apart.
    max_batch_size = 8

    # Previously this configured object was built and then discarded in
    # favour of a default CudaGraphConfig(); pass the configured one.
    cuda_graph_config = None
    if cuda_graph:
        cuda_graph_config = CudaGraphConfig(enable_padding=True,
                                            max_batch_size=max_batch_size)

    pytorch_config = dict(
        max_batch_size=max_batch_size,
        disable_overlap_scheduler=not overlap_scheduler,
        cuda_graph_config=cuda_graph_config)

    kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5,
                                    dtype=kv_cache_dtype)
    spec_config = EagleDecodingConfig(
        max_draft_len=3,
        speculative_model_dir=
        f"{llm_models_root()}/gpt_oss/gpt-oss-120b-Eagle3/",
        eagle3_one_model=True)

    llm = LLM(self.MODEL_PATH,
              tensor_parallel_size=tp_size,
              pipeline_parallel_size=pp_size,
              moe_expert_parallel_size=ep_size,
              kv_cache_config=kv_cache_config,
              **pytorch_config,
              enable_attention_dp=attention_dp,
              moe_config=MoeConfig(backend=moe_backend),
              speculative_config=spec_config)

    with llm:
        model_name = "GPT-OSS/BF16"
        task = GSM8K(model_name)
        task.evaluate(llm,
                      extra_evaluator_kwargs=self.extra_evaluator_kwargs)
3706+
36533707 @pytest .mark .skip_less_device (2 )
36543708 @pytest .mark .parametrize (
36553709 "kv_cache_dtype" ,
0 commit comments