@@ -571,16 +571,20 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
 
     @pytest.mark.skip_less_device(8)
     @skip_pre_blackwell
-    @parametrize_with_ids("overlap_scheduler", [False, True])
-    @parametrize_with_ids("cuda_graph", [False, True])
-    @parametrize_with_ids("attention_dp", [False, True])
-    @parametrize_with_ids("fp8kv", [False, True])
-    @parametrize_with_ids("mtp_nextn", [0, 2])
-    @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(8, 1, 1), (8, 1, 4),
-                                                         (8, 1, 8)],
-                             ids=["tp8", "tp8ep4", "tp8ep8"])
+    @pytest.mark.parametrize(
+        "tp_size,pp_size,ep_size,mtp_nextn,fp8kv,attention_dp,cuda_graph,overlap_scheduler,batch_size,moe_backend",
+        [
+            (8, 1, 4, 3, False, False, True, True, 1, "CUTLASS"),
+            # TODO: enable mtp after bug fix
+            (8, 1, 4, 0, False, False, True, True, 1, "TRTLLM"),
+            (8, 1, 8, 0, True, True, True, True, 24, "CUTLASS"),
+            (8, 1, 1, 0, True, True, True, True, 24, "CUTLASS"),
+        ],
+        ids=["latency", "latency_trtllmgen", "throughput", "throughput_tp8"])
     def test_nvfp4_8gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
-                         attention_dp, cuda_graph, overlap_scheduler):
+                         attention_dp, cuda_graph, overlap_scheduler,
+                         batch_size, moe_backend):
+
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
         pytorch_config = PyTorchConfig(
             enable_overlap_scheduler=overlap_scheduler,
@@ -596,14 +600,16 @@ def test_nvfp4_8gpus(self, tp_size, pp_size, ep_size, mtp_nextn, fp8kv,
         if mtp_nextn > 0:
             mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
         llm = LLM(f"{llm_models_root()}/DeepSeek-R1/DeepSeek-R1-FP4",
+                  batch_size=batch_size,
                   tensor_parallel_size=tp_size,
                   pipeline_parallel_size=pp_size,
                   moe_expert_parallel_size=ep_size,
                   kv_cache_config=kv_cache_config,
                   pytorch_backend_config=pytorch_config,
                   quant_config=quant_config,
                   enable_attention_dp=attention_dp,
-                  speculative_config=mtp_config)
+                  speculative_config=mtp_config,
+                  moe_backend=moe_backend)
         assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
         if fp8kv:
             assert llm.args.quant_config.kv_cache_quant_algo == QuantAlgo.FP8