@@ -832,8 +832,12 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
832832 task = GSM8K (self .MODEL_NAME )
833833 task .evaluate (llm )
834834
835- @pytest .mark .skip_less_device (4 )
836- def test_auto_dtype_with_helix (self ):
835+ @pytest .mark .skip_less_device (8 )
836+ @pytest .mark .parametrize ("gen_pp,gen_tp,gen_cp" , [(1 , 1 , 4 ), (1 , 2 , 2 ),
837+ (2 , 1 , 2 )],
838+ ids = ["pp1tp1cp4" , "pp1tp2cp2" , "pp2tp1cp2" ])
839+ def test_auto_dtype_with_helix (self , gen_pp , gen_tp , gen_cp ):
840+ gen_ep = gen_tp * gen_cp
837841 kv_cache_config = {
838842 "free_gpu_memory_fraction" : 0.5 ,
839843 "enable_block_reuse" : False ,
@@ -842,7 +846,7 @@ def test_auto_dtype_with_helix(self):
842846 }
843847 ctx_server_config = {
844848 "pipeline_parallel_size" : 1 ,
845- "tensor_parallel_size" : 2 ,
849+ "tensor_parallel_size" : 4 ,
846850 "context_parallel_size" : 1 ,
847851 "disable_overlap_scheduler" : True ,
848852 "kv_cache_config" : kv_cache_config ,
@@ -853,9 +857,10 @@ def test_auto_dtype_with_helix(self):
853857 },
854858 }
855859 gen_server_config = {
856- "tensor_parallel_size" : 1 ,
857- "pipeline_parallel_size" : 1 ,
858- "context_parallel_size" : 2 ,
860+ "tensor_parallel_size" : gen_tp ,
861+ "pipeline_parallel_size" : gen_pp ,
862+ "context_parallel_size" : gen_cp ,
863+ "moe_expert_parallel_size" : gen_ep ,
859864 "cp_config" : {
860865 "cp_type" : "HELIX" ,
861866 "tokens_per_block" : 32
@@ -889,8 +894,8 @@ def test_auto_dtype_with_helix(self):
889894 self .MODEL_PATH ) as llm :
890895 task = MMLU (self .MODEL_NAME )
891896 task .evaluate (llm )
892- task = GSM8K (self .MODEL_NAME )
893- task .evaluate (llm )
897+ # task = GSM8K(self.MODEL_NAME)
898+ # task.evaluate(llm)
894899
895900 @pytest .mark .skip_less_device (2 )
896901 @pytest .mark .skip_less_device_memory (60000 )
0 commit comments