Commit be48cdf

[TRTLLM-9466][test] Evaluate helix parallelism with DSV3 Lite (#9597)
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 1a46bb0 commit be48cdf

File tree: 3 files changed (+90, -15 lines)

tensorrt_llm/mapping.py

Lines changed: 13 additions & 2 deletions
@@ -62,8 +62,19 @@ def __init__(
         if moe_cluster_size == -1:
             moe_cluster_size = 1

-        cp_type = CpType.ULYSSES if cp_config is None else cp_config.get(
-            "cp_type", CpType.ULYSSES)
+        # Set default cp_type to ULYSSES.
+        cp_type = CpType.ULYSSES
+
+        # Convert cp_type to CpType enum if it is a string.
+        if cp_config is not None:
+            if "cp_type" in cp_config and isinstance(cp_config["cp_type"], str):
+                try:
+                    cp_config["cp_type"] = CpType[cp_config["cp_type"].upper()]
+                except KeyError:
+                    raise ValueError(f"Invalid cp_type: {cp_config['cp_type']}. " \
+                        f"Must be one of: {', '.join([t.name for t in CpType])}")
+            cp_type = cp_config.get("cp_type", CpType.ULYSSES)
+
         moe_world_size = tp_size if cp_type == CpType.ULYSSES else tp_size * cp_size

         if moe_tp_size == -1 and moe_ep_size == -1:
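To make the change concrete, here is a minimal standalone sketch of the same normalization. CpType here is a stand-in enum, not the real one defined in TensorRT-LLM, and only the ULYSSES and HELIX members are taken from this diff. The new disaggregated-serving test below relies on this behavior by passing cp_type as the string "HELIX" inside cp_config.

# Standalone sketch of the normalization above; CpType is a stand-in enum.
from enum import Enum


class CpType(Enum):
    ULYSSES = 0
    HELIX = 1


def resolve_cp_type(cp_config):
    # Default to ULYSSES, then accept either a CpType member or a
    # case-insensitive string such as "helix" / "HELIX".
    cp_type = CpType.ULYSSES
    if cp_config is not None:
        if "cp_type" in cp_config and isinstance(cp_config["cp_type"], str):
            try:
                cp_config["cp_type"] = CpType[cp_config["cp_type"].upper()]
            except KeyError:
                raise ValueError(
                    f"Invalid cp_type: {cp_config['cp_type']}. "
                    f"Must be one of: {', '.join(t.name for t in CpType)}")
        cp_type = cp_config.get("cp_type", CpType.ULYSSES)
    return cp_type


assert resolve_cp_type({"cp_type": "helix"}) is CpType.HELIX
assert resolve_cp_type(None) is CpType.ULYSSES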

tests/integration/defs/accuracy/references/gsm8k.yaml

Lines changed: 5 additions & 0 deletions
@@ -70,6 +70,11 @@ deepseek-ai/DeepSeek-V3-Lite:
     kv_cache_quant_algo: FP8
     spec_dec_algo: MTP
     accuracy: 64.14
+  # https://nvbugs/5637012: Currently, BS>1 has accuracy issues with helix for GSM8K.
+  # BS=1 has expected accuracy but will be too slow for CI testing. So, adding this
+  # accuracy spec while we investigate the issue.
+  - extra_acc_spec: helix_with_bs8
+    accuracy: 50.0
 deepseek-ai/DeepSeek-R1:
   - quant_algo: NVFP4
     accuracy: 95.42
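As an illustration of how such an entry can be consumed, the sketch below shows one way a reference accuracy could be selected by extra_acc_spec. Both the reference_accuracy helper and the trimmed YAML are hypothetical, not the actual accuracy-harness code.

# Hypothetical sketch (not the test harness's real implementation) of
# selecting a reference accuracy by extra_acc_spec from a gsm8k.yaml-style file.
import yaml

REFERENCES = yaml.safe_load("""
deepseek-ai/DeepSeek-V3-Lite:
  - accuracy: 64.14
  - extra_acc_spec: helix_with_bs8
    accuracy: 50.0
""")


def reference_accuracy(model, extra_acc_spec=None):
    # Entries without an extra_acc_spec key only match when no spec is given.
    for entry in REFERENCES[model]:
        if entry.get("extra_acc_spec") == extra_acc_spec:
            return entry["accuracy"]
    raise KeyError(f"No reference for {model!r} with spec {extra_acc_spec!r}")


assert reference_accuracy("deepseek-ai/DeepSeek-V3-Lite",
                          "helix_with_bs8") == 50.0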

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 72 additions & 13 deletions
@@ -159,17 +159,17 @@ def _apply_perf_flags(cfg: Optional[Dict[str, Any]]):
         "--backend",
         "pytorch",
     ]
-    gen_tp, gen_pp = gen_server_config.get(
-        "tensor_parallel_size",
-        tensor_parallel_size), gen_server_config.get("pipeline_parallel_size",
-                                                     1)
-    ctx_tp, ctx_pp = ctx_server_config.get(
-        "tensor_parallel_size",
-        tensor_parallel_size), ctx_server_config.get("pipeline_parallel_size",
-                                                     1)
-
-    ctx_total_gpus = ctx_tp * ctx_pp
-    gen_total_gpus = gen_tp * gen_pp
+    gen_tp, gen_pp, gen_cp = gen_server_config.get(
+        "tensor_parallel_size", tensor_parallel_size), gen_server_config.get(
+            "pipeline_parallel_size",
+            1), gen_server_config.get("context_parallel_size", 1)
+    ctx_tp, ctx_pp, ctx_cp = ctx_server_config.get(
+        "tensor_parallel_size", tensor_parallel_size), ctx_server_config.get(
+            "pipeline_parallel_size",
+            1), ctx_server_config.get("context_parallel_size", 1)
+
+    ctx_total_gpus = ctx_tp * ctx_pp * ctx_cp
+    gen_total_gpus = gen_tp * gen_pp * gen_cp

     ctx_urls = disaggregated_server_config["context_servers"]["urls"]
     gen_urls = disaggregated_server_config["generation_servers"]["urls"]
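A minimal sketch of the per-server GPU accounting this hunk introduces, assuming each server config is a plain dict like the ones built later in this test file:

# Sketch of the GPU accounting above: context parallelism now counts toward
# the number of GPUs a context or generation server needs.
def total_gpus(server_config, default_tp=1):
    tp = server_config.get("tensor_parallel_size", default_tp)
    pp = server_config.get("pipeline_parallel_size", 1)
    cp = server_config.get("context_parallel_size", 1)
    return tp * pp * cp


# A helix generation server with tp=1, pp=1, cp=2 needs 2 GPUs.
assert total_gpus({"tensor_parallel_size": 1, "context_parallel_size": 2}) == 2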
@@ -196,7 +196,7 @@ def _apply_perf_flags(cfg: Optional[Dict[str, Any]]):
     ctx_server_args = ctx_args + [
         "--port",
         str(port), "--extra_llm_api_options", ctx_server_config_path,
-        f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
+        f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}", f"--cp_size={ctx_cp}"
     ]
     if "max_num_tokens" in ctx_server_config:
         ctx_server_args.append(
@@ -219,7 +219,7 @@ def _apply_perf_flags(cfg: Optional[Dict[str, Any]]):
     gen_server_args = gen_args + [
         "--port",
         str(port), "--extra_llm_api_options", gen_server_config_path,
-        f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
+        f"--tp_size={gen_tp}", f"--pp_size={gen_pp}", f"--cp_size={gen_cp}"
     ]
     if "max_num_tokens" in gen_server_config:
         gen_server_args.append(
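With the generation config used by the helix test below (tp=1, pp=1, cp=2), the extra flags appended here would come out as follows; the values are taken from that example only:

# Illustrative values taken from test_auto_dtype_with_helix below.
gen_tp, gen_pp, gen_cp = 1, 1, 2
flags = [f"--tp_size={gen_tp}", f"--pp_size={gen_pp}", f"--cp_size={gen_cp}"]
assert flags == ["--tp_size=1", "--pp_size=1", "--cp_size=2"]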
@@ -853,6 +853,65 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.skip_less_device(4)
+    def test_auto_dtype_with_helix(self):
+        kv_cache_config = {
+            "free_gpu_memory_fraction": 0.5,
+            "enable_block_reuse": False,
+            "enable_partial_reuse": False,
+            "tokens_per_block": 32,
+        }
+        ctx_server_config = {
+            "pipeline_parallel_size": 1,
+            "tensor_parallel_size": 2,
+            "context_parallel_size": 1,
+            "max_batch_size": 8,
+            "disable_overlap_scheduler": True,
+            "kv_cache_config": kv_cache_config,
+            "enable_chunked_prefill": False,
+            "cuda_graph_config": None,
+            "cache_transceiver_config": {
+                "backend": "UCX"
+            },
+        }
+        gen_server_config = {
+            "tensor_parallel_size": 1,
+            "pipeline_parallel_size": 1,
+            "context_parallel_size": 2,
+            "cp_config": {
+                "cp_type": "HELIX",
+                "tokens_per_block": 32
+            },
+            "max_batch_size": 8,
+            "disable_overlap_scheduler": True,
+            "kv_cache_config": kv_cache_config,
+            "enable_chunked_prefill": False,
+            "cuda_graph_config": None,
+            "cache_transceiver_config": {
+                "backend": "UCX"
+            },
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm, extra_acc_spec="helix_with_bs8")
+
     @pytest.mark.skip_less_device(2)
     @pytest.mark.skip_less_device_memory(60000)
     @parametrize_with_ids("mtp_nextn", [0, 2])
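A back-of-the-envelope check, not part of the diff: the context server uses tp=2, pp=1, cp=1 and the helix generation server uses tp=1, pp=1, cp=2, so the test needs 2 + 2 = 4 GPUs, which matches the skip_less_device(4) marker.

# GPU budget for test_auto_dtype_with_helix, derived from the configs above.
ctx_gpus = 2 * 1 * 1  # tensor_parallel_size=2, pipeline/context parallelism = 1
gen_gpus = 1 * 1 * 2  # context_parallel_size=2 (helix), tensor/pipeline = 1
assert ctx_gpus + gen_gpus == 4  # matches @pytest.mark.skip_less_device(4)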

0 commit comments
