2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/autotuner.py
@@ -1506,7 +1506,7 @@ def _broadcast_cache_data(
"""Broadcast tactics from root rank to all other ranks."""
cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
root = 0
cache_data = self._dist.tp_broadcast(obj=cache_data, root=root)
cache_data = self._dist.tp_cp_broadcast(obj=cache_data, root=root)

self.profiling_cache.merge_cache_data(cache_data)

50 changes: 48 additions & 2 deletions tensorrt_llm/_torch/distributed/communicator.py
@@ -116,6 +116,26 @@ def broadcast(self, obj, root=0):
def allgather(self, obj, root=0):
pass

@abstractmethod
def tp_broadcast(self, obj, root=0, **kwargs):
pass

@abstractmethod
def cp_broadcast(self, obj, root=0, **kwargs):
pass

def tp_cp_broadcast(self, obj, root=0, **kwargs):
"""Broadcast object across both TP and CP groups.

This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
First broadcasts within the TP group, then within the CP group.
"""
if self.tp_size > 1:
obj = self.tp_broadcast(obj, root=root, **kwargs)
if self.cp_size > 1:
obj = self.cp_broadcast(obj, root=root, **kwargs)
return obj


def safe_broadcast(comm, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
"""
@@ -407,14 +427,26 @@ def create_cp_comm(self):
def cp_allgather(self, obj):
return self.cp_comm.allgather(obj)

def cp_broadcast(self,
obj,
root=0,
chunk_size: int = 4 * 1024 * 1024,
**kwargs):
comm = self.cp_comm
return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)

def tp_allgather(self, obj):
return self.tp_comm.allgather(obj)

def tp_gather(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
comm = self.tp_comm
return safe_gather(comm, obj, root=root, chunk_size=chunk_size)

def tp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
def tp_broadcast(self,
obj,
root=0,
chunk_size: int = 4 * 1024 * 1024,
**kwargs):
comm = self.tp_comm
return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)

@@ -699,7 +731,7 @@ def tp_gather(self, obj, dst=0):
return output_list

@log_op
def tp_broadcast(self, obj, root=0):
def tp_broadcast(self, obj, root=0, **kwargs):
if isinstance(obj, torch.Tensor):
dist.broadcast(obj, src=root, group=self.mapping.tp_group_pg)
return obj
Expand All @@ -712,6 +744,20 @@ def tp_broadcast(self, obj, root=0):
device=torch.device("cpu"))
return ret[0]

@log_op
def cp_broadcast(self, obj, root=0, **kwargs):
if isinstance(obj, torch.Tensor):
dist.broadcast(obj, src=root, group=self.mapping.cp_group_pg)
return obj
else:
ret = [obj]
torch.distributed.broadcast_object_list(
ret,
src=root,
group=self.mapping.cp_group_pg,
device=torch.device("cpu"))
return ret[0]

@log_op
def pp_allgather(self, obj):
if isinstance(obj, torch.Tensor):
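To make the composition in `tp_cp_broadcast` concrete, here is a small, self-contained sketch (plain Python, illustrative only; it models the rank grid with a dict rather than real process groups) of why a TP-group broadcast followed by a CP-group broadcast delivers the root's object to every rank of a tp_size x cp_size stage:

```python
# Illustrative sketch only: model a tp_size x cp_size stage as a dict keyed by
# (tp, cp). Only the global root (tp=0, cp=0) starts with the payload, matching
# tp_cp_broadcast(obj, root=0).
tp_size, cp_size = 2, 2
payload = {"custom_op": "gemm", "tactic": 3}   # stand-in for the cached tactics
ranks = {(tp, cp): None for tp in range(tp_size) for cp in range(cp_size)}
ranks[(0, 0)] = payload

# Stage 1 -- tp_broadcast: each TP group (fixed cp) copies from its tp-rank 0.
# Only the cp=0 group has real data at this point.
for cp in range(cp_size):
    for tp in range(tp_size):
        ranks[(tp, cp)] = ranks[(0, cp)]

# Stage 2 -- cp_broadcast: each CP group (fixed tp) copies from its cp-rank 0,
# which now holds the payload thanks to stage 1.
for tp in range(tp_size):
    for cp in range(cp_size):
        ranks[(tp, cp)] = ranks[(tp, 0)]

assert all(v == payload for v in ranks.values())
```

The CP-first order would work just as well; what matters is that the global root is rank 0 of both its TP and CP groups, which holds for root=0.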
7 changes: 4 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -587,9 +587,10 @@ def _broadcast_new_requests(
if not self.dist.has_pp:
return self.dist.broadcast(payloads, root=0)

# Broadcast within first tp group before send/recv chain to other tp groups
if self.dist.tp_size > 1 and self.dist.is_first_pp_rank:
payloads = self.dist.tp_broadcast(payloads, root=0)
# Broadcast within first PP stage before send/recv chain to other PP stages.
# This needs to cover both TP and CP ranks within the first PP stage.
if self.dist.is_first_pp_rank:
payloads = self.dist.tp_cp_broadcast(payloads, root=0)

# Tag for communication
tag = self.dist.pp_size # Use pp_size as tag to avoid conflicts
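Continuing the sketch above: once the first PP stage is fully covered by `tp_cp_broadcast`, the existing send/recv chain only has to forward the payload stage to stage. A minimal, illustrative simulation (again a dict standing in for ranks, not the executor code, and assuming each rank forwards to its same-position peer in the next PP stage):

```python
# Illustrative only: assume the first PP stage (pp=0) already holds the new
# requests after tp_cp_broadcast; the PP chain then forwards them stage by stage.
pp_size, tp_size, cp_size = 2, 1, 2        # matches the new pp2tp1cp2 test layout
payload = ["request-0", "request-1"]
ranks = {(pp, tp, cp): (payload if pp == 0 else None)
         for pp in range(pp_size)
         for tp in range(tp_size)
         for cp in range(cp_size)}

# Send/recv chain: every rank of stage pp receives from its peer in stage pp-1.
# With the old TP-only broadcast, a (pp=0, cp>0) rank would have had nothing to send.
for pp in range(1, pp_size):
    for tp in range(tp_size):
        for cp in range(cp_size):
            ranks[(pp, tp, cp)] = ranks[(pp - 1, tp, cp)]

assert all(v == payload for v in ranks.values())
```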
23 changes: 15 additions & 8 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
@@ -871,7 +871,9 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("gen_pp,gen_tp,gen_cp", [(1, 2, 2), (2, 1, 2)],
ids=["pp1tp2cp2", "pp2tp1cp2"])
@pytest.mark.parametrize("cuda_graph_config", [
None,
{
@@ -888,8 +890,10 @@ def test_auto_dtype(self, overlap_scheduler, mtp_nextn):
"cudagraph:with_padding"
])
@pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
gen_pp, gen_tp, gen_cp):
use_nccl_for_alltoall = comms_medium == "nccl"
gen_ep = gen_tp * gen_cp
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
"enable_block_reuse": False,
@@ -898,20 +902,22 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
}
ctx_server_config = {
"pipeline_parallel_size": 1,
"tensor_parallel_size": 2,
"tensor_parallel_size": 4,
"context_parallel_size": 1,
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
"enable_chunked_prefill": False,
"cuda_graph_config": None,
"cache_transceiver_config": {
"backend": "UCX"
"backend": "UCX",
"max_tokens_in_buffer": 8192,
},
}
gen_server_config = {
"tensor_parallel_size": 1,
"pipeline_parallel_size": 1,
"context_parallel_size": 2,
"tensor_parallel_size": gen_tp,
"pipeline_parallel_size": gen_pp,
"context_parallel_size": gen_cp,
"moe_expert_parallel_size": gen_ep,
"cp_config": {
"cp_type": "HELIX",
"tokens_per_block": 32,
@@ -922,7 +928,8 @@ def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config):
"enable_chunked_prefill": False,
"cuda_graph_config": cuda_graph_config,
"cache_transceiver_config": {
"backend": "UCX"
"backend": "UCX",
"max_tokens_in_buffer": 8192,
},
}
disaggregated_server_config = {
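The GPU-count bump from `skip_less_device(4)` to `skip_less_device(8)` follows directly from the new parametrizations; a quick worked check (assuming, as the config above sets via `gen_ep`, that the generation server's MoE expert parallelism spans its TP x CP ranks, and that the context and generation servers occupy disjoint GPUs in disaggregated serving):

```python
ctx_world = 4  # ctx server: tensor_parallel_size=4, pp=1, cp=1
for gen_pp, gen_tp, gen_cp in [(1, 2, 2), (2, 1, 2)]:
    gen_ep = gen_tp * gen_cp                 # moe_expert_parallel_size
    gen_world = gen_pp * gen_tp * gen_cp     # GPUs used by the gen server
    print(f"pp{gen_pp}tp{gen_tp}cp{gen_cp}: gen={gen_world}, ep={gen_ep}, "
          f"total={ctx_world + gen_world}")
# pp1tp2cp2: gen=4, ep=4, total=8
# pp2tp1cp2: gen=4, ep=2, total=8
```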
12 changes: 6 additions & 6 deletions tests/integration/test_lists/qa/llm_function_core.txt
@@ -540,14 +540,14 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_w4a8_mxfp4[fp8-latency]
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -66,6 +66,7 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -92,6 +93,7 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[baseline_fp8kv] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_fp8_blockscale[latency] TIMEOUT (60)
6 changes: 0 additions & 6 deletions tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml
@@ -72,8 +72,6 @@ l0_gb200_multi_gpus:
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[no_cuda_graph_overlap-cutlass]
- accuracy/test_llm_api_pytorch.py::TestQwen3NextInstruct::test_nvfp4[tp4ep4-trtllm]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:none]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:none]
- accuracy/test_llm_api_pytorch.py::TestMistralLarge3_675B::test_nvfp4_4gpus[latency_moe_trtllm_eagle] TIMEOUT (90)
- condition:
ranges:
@@ -89,10 +87,6 @@ l0_gb200_multi_gpus:
stage: post_merge
backend: pytorch
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:without_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:without_padding]
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]