@@ -4177,14 +4177,16 @@ def test_w4_chunked_prefill(self, kv_cache_dtype, moe_backend, mocker):
41774177 ["CUTLASS" ,
41784178 pytest .param ("TRTLLM" , marks = skip_pre_blackwell ), "TRITON" ],
41794179 ids = ["cutlass" , "trtllm" , "triton" ])
4180- def test_eagle3 (self , moe_backend , one_model , overlap_scheduler , mocker ):
4180+ def test_eagle3_4gpus (self , moe_backend , one_model , overlap_scheduler ,
4181+ mocker ):
41814182 if moe_backend == "TRITON" :
41824183 if not IS_TRITON_KERNELS_AVAILABLE :
41834184 pytest .skip ("Triton kernels are not available" )
41844185
4185- if get_sm_version () == 90 and moe_backend == "CUTLASS" :
4186+ if get_sm_version () == 90 :
41864187 pytest .skip (
4187- "https://nvbugs/5636916: Remaining Hopper Eagle Accuracy Issue" )
4188+ "https://nvbugs/5636916: Remaining Hopper Eagle Accuracy Issue for only TP=4"
4189+ )
41884190
41894191 MAX_OUTPUT_LEN = 128179
41904192 MAX_INPUT_LEN = 32768
@@ -4247,6 +4249,86 @@ def test_eagle3(self, moe_backend, one_model, overlap_scheduler, mocker):
42474249 sampling_params = sampling_params ,
42484250 extra_evaluator_kwargs = extra_evaluator_kwargs )
42494251
4252+ @pytest .mark .skip_less_device (2 )
4253+ @pytest .mark .timeout (14400 )
4254+ @pytest .mark .parametrize ("overlap_scheduler" , [True , False ],
4255+ ids = ["overlap_scheduler" , "no_overlap_scheduler" ])
4256+ @pytest .mark .parametrize ("one_model" , [True , False ],
4257+ ids = ["one_model" , "two_model" ])
4258+ @pytest .mark .parametrize (
4259+ "moe_backend" ,
4260+ ["CUTLASS" ,
4261+ pytest .param ("TRTLLM" , marks = skip_pre_blackwell ), "TRITON" ],
4262+ ids = ["cutlass" , "trtllm" , "triton" ])
4263+ def test_eagle3_2gpus (self , moe_backend , one_model , overlap_scheduler ,
4264+ mocker ):
4265+ if moe_backend == "TRITON" :
4266+ if not IS_TRITON_KERNELS_AVAILABLE :
4267+ pytest .skip ("Triton kernels are not available" )
4268+
4269+ MAX_OUTPUT_LEN = 128179
4270+ MAX_INPUT_LEN = 32768
4271+
4272+ mocker .patch .object (GSM8K , "MAX_OUTPUT_LEN" , 8192 )
4273+ mocker .patch .dict (GSM8K .EVALUATE_KWARGS ,
4274+ {"scores_filter" : "exact_match,flexible-extract" })
4275+
4276+ mocker .patch .object (GPQADiamond , "MAX_OUTPUT_LEN" , MAX_OUTPUT_LEN )
4277+ mocker .patch .object (GPQADiamond , "MAX_INPUT_LEN" , MAX_INPUT_LEN )
4278+
4279+ # https://nvbugs/5590408: 2-Model overlap scheduling has accuracy issue
4280+ pytorch_config = dict (
4281+ max_batch_size = 8 ,
4282+ disable_overlap_scheduler = not overlap_scheduler ,
4283+ cuda_graph_config = CudaGraphConfig (max_batch_size = 8 ))
4284+ kv_cache_config = KvCacheConfig (free_gpu_memory_fraction = 0.4 ,
4285+ dtype = "auto" )
4286+
4287+ eagle_model_dir = f"{ llm_models_root ()} /gpt_oss/gpt-oss-120b-Eagle3"
4288+ draft_len = 3
4289+ spec_config = EagleDecodingConfig (max_draft_len = draft_len ,
4290+ speculative_model_dir = eagle_model_dir ,
4291+ eagle3_one_model = one_model )
4292+
4293+ max_seq_len = MAX_INPUT_LEN + MAX_OUTPUT_LEN
4294+ llm = LLM (self .MODEL_PATH ,
4295+ tensor_parallel_size = 2 ,
4296+ pipeline_parallel_size = 1 ,
4297+ moe_expert_parallel_size = 1 ,
4298+ kv_cache_config = kv_cache_config ,
4299+ max_seq_len = max_seq_len ,
4300+ speculative_config = spec_config ,
4301+ ** pytorch_config ,
4302+ enable_attention_dp = False ,
4303+ moe_config = MoeConfig (backend = moe_backend ))
4304+
4305+ with llm :
4306+ model_name = "GPT-OSS/120B-MXFP4"
4307+
4308+ # GSM8K
4309+ task = GSM8K (model_name )
4310+ task .evaluate (llm ,
4311+ extra_evaluator_kwargs = self .extra_evaluator_kwargs )
4312+
4313+ # GPQA Medium Reasoning
4314+ task = GPQADiamond (model_name )
4315+
4316+ chat_template_kwargs = dict (reasoning_effort = "medium" )
4317+ extra_evaluator_kwargs = {
4318+ ** self .extra_evaluator_kwargs , "chat_template_kwargs" :
4319+ chat_template_kwargs
4320+ }
4321+
4322+ sampling_params = SamplingParams (
4323+ temperature = 1.0 ,
4324+ top_p = 1.0 ,
4325+ max_tokens = MAX_OUTPUT_LEN ,
4326+ truncate_prompt_tokens = MAX_INPUT_LEN )
4327+
4328+ task .evaluate (llm ,
4329+ sampling_params = sampling_params ,
4330+ extra_evaluator_kwargs = extra_evaluator_kwargs )
4331+
42504332 @pytest .mark .skip_less_device (4 )
42514333 @pytest .mark .skip_device_not_contain (["GB200" ])
42524334 @pytest .mark .parametrize (
0 commit comments