@@ -850,8 +850,12 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
         task = GSM8K(self.MODEL_NAME)
         task.evaluate(llm)
 
-    @pytest.mark.skip_less_device(4)
-    def test_auto_dtype_with_helix(self):
+    @pytest.mark.skip_less_device(8)
+    @pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2),
+                                                      (2, 1, 2)],
+                             ids=["pp1tp1cp4", "pp1tp2cp2", "pp2tp1cp2"])
+    def test_auto_dtype_with_helix(self, gen_pp, gen_tp, gen_cp):
+        gen_ep = gen_tp * gen_cp
         kv_cache_config = {
             "free_gpu_memory_fraction": 0.5,
             "enable_block_reuse": False,
@@ -860,7 +864,7 @@ def test_auto_dtype_with_helix(self):
         }
         ctx_server_config = {
             "pipeline_parallel_size": 1,
-            "tensor_parallel_size": 2,
+            "tensor_parallel_size": 4,
             "context_parallel_size": 1,
             "disable_overlap_scheduler": True,
             "kv_cache_config": kv_cache_config,
@@ -871,9 +875,10 @@ def test_auto_dtype_with_helix(self):
             },
         }
         gen_server_config = {
-            "tensor_parallel_size": 1,
-            "pipeline_parallel_size": 1,
-            "context_parallel_size": 2,
+            "tensor_parallel_size": gen_tp,
+            "pipeline_parallel_size": gen_pp,
+            "context_parallel_size": gen_cp,
+            "moe_expert_parallel_size": gen_ep,
             "cp_config": {
                 "cp_type": "HELIX",
                 "tokens_per_block": 32
@@ -904,8 +909,8 @@ def test_auto_dtype_with_helix(self):
                          self.MODEL_PATH) as llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm)
-            task = GSM8K(self.MODEL_NAME)
-            task.evaluate(llm)
+            # task = GSM8K(self.MODEL_NAME)
+            # task.evaluate(llm)
 
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(60000)
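
Note: a minimal standalone sketch of the parametrization added above, assuming pytest. The helper build_gen_server_config and the test test_gen_layout_ranks are hypothetical names used only for illustration; the dict values mirror the literals in this diff. It also illustrates why the device requirement rises from 4 to 8 GPUs: each generation layout occupies pp * tp * cp = 4 ranks next to the TP=4 context server.

import pytest


# Hypothetical helper mirroring gen_server_config in the diff; expert
# parallelism is derived as TP * CP, matching `gen_ep` above.
def build_gen_server_config(gen_pp: int, gen_tp: int, gen_cp: int) -> dict:
    gen_ep = gen_tp * gen_cp
    return {
        "tensor_parallel_size": gen_tp,
        "pipeline_parallel_size": gen_pp,
        "context_parallel_size": gen_cp,
        "moe_expert_parallel_size": gen_ep,
        "cp_config": {
            "cp_type": "HELIX",
            "tokens_per_block": 32
        },
    }


@pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 1, 4), (1, 2, 2),
                                                  (2, 1, 2)],
                         ids=["pp1tp1cp4", "pp1tp2cp2", "pp2tp1cp2"])
def test_gen_layout_ranks(gen_pp, gen_tp, gen_cp):
    # Every parametrized layout uses 4 generation ranks; together with the
    # context server's tensor_parallel_size=4 this motivates
    # skip_less_device(8).
    config = build_gen_server_config(gen_pp, gen_tp, gen_cp)
    assert gen_pp * gen_tp * gen_cp == 4
    assert config["moe_expert_parallel_size"] == gen_tp * gen_cp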