@@ -798,8 +798,12 @@ def test_fp8_block_scales_4gpus_static_eplb(self):
         (False, False, True, False),
         (False, False, False, True),
         (True, False, True, True), (True, True, True, True)])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM"])
     def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
-                   torch_compile):
+                   torch_compile, mtp_nextn, moe_backend):
+        if torch_compile and mtp_nextn > 0:
+            pytest.skip("https://nvbugs/5252313")
         if torch_compile and attention_dp:
             pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
@@ -810,18 +814,24 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,
             torch_compile_config=torch_compile_config,
+            moe_backend=moe_backend,
         )
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+
         quant_config = QuantConfig()
         quant_config.quant_algo = QuantAlgo.NVFP4
         if fp8kv:
             quant_config.kv_cache_quant_algo = QuantAlgo.FP8
             pytorch_config["kv_cache_dtype"] = "fp8"

-        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
                   kv_cache_config=kv_cache_config,
                   **pytorch_config,
                   quant_config=quant_config,
-                  enable_attention_dp=attention_dp)
+                  enable_attention_dp=attention_dp,
+                  speculative_config=mtp_config)

         assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
         if fp8kv:
@@ -850,9 +860,13 @@ def test_nvfp4(self, fp8kv, attention_dp, cuda_graph, overlap_scheduler,
     @pytest.mark.parametrize("tp_size,pp_size,ep_size", [(4, 1, 1), (4, 1, 4),
                                                          (2, 2, 1), (1, 4, 1)],
                              ids=["tp4", "ep4", "tp2pp2", "pp4"])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    @parametrize_with_ids("moe_backend", ["CUTLASS", "TRTLLM"])
     def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
                          overlap_scheduler, tp_size, pp_size, ep_size,
-                         torch_compile):
+                         torch_compile, mtp_nextn, moe_backend):
+        if torch_compile and mtp_nextn > 0:
+            pytest.skip("https://nvbugs/5252313")
         if torch_compile and attention_dp:
             pytest.skip("https://nvbugs/5252559")
         if torch_compile and pp_size > 1:
@@ -867,22 +881,28 @@ def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
             disable_overlap_scheduler=not overlap_scheduler,
             use_cuda_graph=cuda_graph,
             torch_compile_config=torch_compile_config,
+            moe_backend=moe_backend,
         )

+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+
         quant_config = QuantConfig()
         quant_config.quant_algo = QuantAlgo.NVFP4
         if fp8kv:
             quant_config.kv_cache_quant_algo = QuantAlgo.FP8
             pytorch_config["kv_cache_dtype"] = "fp8"

-        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only_mtp",
                   tensor_parallel_size=tp_size,
                   pipeline_parallel_size=pp_size,
                   moe_expert_parallel_size=ep_size,
                   kv_cache_config=kv_cache_config,
                   **pytorch_config,
                   quant_config=quant_config,
-                  enable_attention_dp=attention_dp)
+                  enable_attention_dp=attention_dp,
+                  speculative_config=mtp_config)

         assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
         if fp8kv: