@@ -159,17 +159,17 @@ def _apply_perf_flags(cfg: Optional[Dict[str, Any]]):
         "--backend",
         "pytorch",
     ]
-    gen_tp, gen_pp = gen_server_config.get(
-        "tensor_parallel_size",
-        tensor_parallel_size), gen_server_config.get("pipeline_parallel_size",
-                                                     1)
-    ctx_tp, ctx_pp = ctx_server_config.get(
-        "tensor_parallel_size",
-        tensor_parallel_size), ctx_server_config.get("pipeline_parallel_size",
-                                                     1)
-
-    ctx_total_gpus = ctx_tp * ctx_pp
-    gen_total_gpus = gen_tp * gen_pp
+    gen_tp, gen_pp, gen_cp = gen_server_config.get(
+        "tensor_parallel_size", tensor_parallel_size), gen_server_config.get(
+            "pipeline_parallel_size",
+            1), gen_server_config.get("context_parallel_size", 1)
+    ctx_tp, ctx_pp, ctx_cp = ctx_server_config.get(
+        "tensor_parallel_size", tensor_parallel_size), ctx_server_config.get(
+            "pipeline_parallel_size",
+            1), ctx_server_config.get("context_parallel_size", 1)
+
+    ctx_total_gpus = ctx_tp * ctx_pp * ctx_cp
+    gen_total_gpus = gen_tp * gen_pp * gen_cp
 
     ctx_urls = disaggregated_server_config["context_servers"]["urls"]
     gen_urls = disaggregated_server_config["generation_servers"]["urls"]
@@ -196,7 +196,7 @@ def _apply_perf_flags(cfg: Optional[Dict[str, Any]]):
         ctx_server_args = ctx_args + [
             "--port",
             str(port), "--extra_llm_api_options", ctx_server_config_path,
-            f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
+            f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}", f"--cp_size={ctx_cp}"
         ]
         if "max_num_tokens" in ctx_server_config:
             ctx_server_args.append(
@@ -219,7 +219,7 @@ def _apply_perf_flags(cfg: Optional[Dict[str, Any]]):
         gen_server_args = gen_args + [
             "--port",
             str(port), "--extra_llm_api_options", gen_server_config_path,
-            f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
+            f"--tp_size={gen_tp}", f"--pp_size={gen_pp}", f"--cp_size={gen_cp}"
         ]
         if "max_num_tokens" in gen_server_config:
             gen_server_args.append(
@@ -853,6 +853,65 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.skip_less_device(4)
+    def test_auto_dtype_with_helix(self):
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False,
+            "enable_partial_reuse": False,
+            "tokens_per_block": 32,
+        }
+        ctx_server_config = {
+            "pipeline_parallel_size": 1,
+            "tensor_parallel_size": 2,
+            "context_parallel_size": 1,
+            "max_batch_size": 8,
+            "disable_overlap_scheduler": True,
+            "kv_cache_config": kv_cache_config,
+            "enable_chunked_prefill": False,
+            "cuda_graph_config": None,
+            "cache_transceiver_config": {
+                "backend": "UCX"
+            },
+        }
+        gen_server_config = {
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "context_parallel_size": 2,
+            "cp_config": {
+                "cp_type": "HELIX",
+                "tokens_per_block": 32
+            },
+            "max_batch_size": 8,
+            "disable_overlap_scheduler": True,
+            "kv_cache_config": kv_cache_config,
+            "enable_chunked_prefill": False,
+            "cuda_graph_config": None,
+            "cache_transceiver_config": {
+                "backend": "UCX"
+            },
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm, extra_acc_spec="helix_with_bs8")
+
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(60000)
     @parametrize_with_ids("mtp_nextn", [0, 2])
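The new test asks for four GPUs, which follows directly from the revised accounting in the first hunk: the context servers need tp * pp * cp = 2 * 1 * 1 = 2 GPUs and the Helix generation server needs 1 * 1 * 2 = 2, launched with `--tp_size=1 --pp_size=1 --cp_size=2`. A minimal sketch of that arithmetic, assuming only the config dicts shown in the diff (the `total_gpus` helper is illustrative, not part of the test harness):

```python
# Illustrative sketch of the per-server GPU accounting; `total_gpus` is a
# hypothetical helper mirroring the launcher's tp * pp * cp computation.
def total_gpus(server_config: dict, default_tp: int = 1) -> int:
    tp = server_config.get("tensor_parallel_size", default_tp)
    pp = server_config.get("pipeline_parallel_size", 1)
    cp = server_config.get("context_parallel_size", 1)  # new in this change
    return tp * pp * cp

ctx = {"tensor_parallel_size": 2, "pipeline_parallel_size": 1,
       "context_parallel_size": 1}
gen = {"tensor_parallel_size": 1, "pipeline_parallel_size": 1,
       "context_parallel_size": 2, "cp_config": {"cp_type": "HELIX"}}

assert total_gpus(ctx) == 2  # context servers: 2 * 1 * 1
assert total_gpus(gen) == 2  # generation server: 1 * 1 * 2 (Helix CP)
# 2 + 2 = 4 GPUs total, matching @pytest.mark.skip_less_device(4) on the test.
```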