
Commit f16ee9b

address comments from Yuxian
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent e9af2f3 commit f16ee9b

5 files changed (+146, -7 lines)

tensorrt_llm/_torch/autotuner.py

Lines changed: 1 addition & 1 deletion
@@ -1506,7 +1506,7 @@ def _broadcast_cache_data(
         """Broadcast tactics from root rank to all other ranks."""
         cache_data = self.profiling_cache.get_specific_custom_op(custom_op)
         root = 0
-        cache_data = self._dist.tp_broadcast(obj=cache_data, root=root)
+        cache_data = self._dist.tp_cp_broadcast(obj=cache_data, root=root)
 
         self.profiling_cache.merge_cache_data(cache_data)
 

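Note on this call site: with helix parallelism enabled, tp_broadcast alone reaches only the root's TP peers, leaving CP ranks to profile tactics independently; switching to tp_cp_broadcast (added in communicator.py below) extends the same broadcast to the full TP×CP group.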
tensorrt_llm/_torch/distributed/communicator.py

Lines changed: 25 additions & 0 deletions
@@ -411,6 +411,18 @@ def cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
         comm = self.cp_comm
         return safe_broadcast(comm, obj, root=root, chunk_size=chunk_size)
 
+    def tp_cp_broadcast(self, obj, root=0, chunk_size: int = 4 * 1024 * 1024):
+        """Broadcast object across both TP and CP groups.
+
+        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
+        First broadcasts within the TP group, then within the CP group.
+        """
+        if self.tp_size > 1:
+            obj = self.tp_broadcast(obj, root=root, chunk_size=chunk_size)
+        if self.cp_size > 1:
+            obj = self.cp_broadcast(obj, root=root, chunk_size=chunk_size)
+        return obj
+
     def tp_allgather(self, obj):
         return self.tp_comm.allgather(obj)
 
@@ -730,6 +742,19 @@ def cp_broadcast(self, obj, root=0):
                                  device=torch.device("cpu"))
         return ret[0]
 
+    @log_op
+    def tp_cp_broadcast(self, obj, root=0):
+        """Broadcast object across both TP and CP groups.
+
+        This is used when both TP and CP parallelism are enabled (e.g., helix parallelism).
+        First broadcasts within the TP group, then within the CP group.
+        """
+        if self.tp_size > 1:
+            obj = self.tp_broadcast(obj, root=root)
+        if self.cp_size > 1:
+            obj = self.cp_broadcast(obj, root=root)
+        return obj
+
     @log_op
     def pp_allgather(self, obj):
         if isinstance(obj, torch.Tensor):

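Both the MPI-backed and torch.distributed-backed variants share the same two-stage shape: broadcast within the TP group first, then within the CP group, so the payload fans out from the root to every rank of the TP×CP grid. Below is a minimal plain-Python sketch of that fan-out (no MPI or torch.distributed involved; the row-major rank layout is an illustrative assumption, not necessarily how Mapping orders ranks):

# Minimal sketch of the two-stage fan-out, assuming a row-major rank grid
# (rank = cp_index * tp_size + tp_index); the real Mapping may differ.

def tp_cp_broadcast_sim(values, tp_size, cp_size, root=0):
    """values maps rank -> payload; only the global root starts non-None."""
    # Stage 1: every TP group broadcasts from its tp_rank == root member.
    # Only the root's own TP group actually propagates the payload here.
    for cp in range(cp_size):
        group = [cp * tp_size + tp for tp in range(tp_size)]
        src = group[root]
        for r in group:
            values[r] = values[src]
    # Stage 2: every CP group broadcasts from its cp_rank == root member,
    # pushing the payload down each "column" of the grid.
    for tp in range(tp_size):
        group = [cp * tp_size + tp for cp in range(cp_size)]
        src = group[root]
        for r in group:
            values[r] = values[src]
    return values

vals = {rank: None for rank in range(4)}
vals[0] = {"tactic": "best"}  # only the global root holds the object
out = tp_cp_broadcast_sim(vals, tp_size=2, cp_size=2)
assert all(v == {"tactic": "best"} for v in out.values())

After stage 1 the payload lives on the root's entire TP group; stage 2 then pushes it down each CP group, which is also why either stage can be skipped independently whenever the corresponding group has size 1.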
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 1 addition & 6 deletions
@@ -590,12 +590,7 @@ def _broadcast_new_requests(
         # Broadcast within first PP stage before send/recv chain to other PP stages.
         # This needs to cover both TP and CP ranks within the first PP stage.
         if self.dist.is_first_pp_rank:
-            if self.dist.tp_size > 1:
-                payloads = self.dist.tp_broadcast(payloads, root=0)
-            # Also broadcast within CP group when CP is enabled (helix parallelism).
-            # This ensures all CP ranks within the first PP stage receive the requests.
-            if self.dist.cp_size > 1:
-                payloads = self.dist.cp_broadcast(payloads, root=0)
+            payloads = self.dist.tp_cp_broadcast(payloads, root=0)
 
         # Tag for communication
         tag = self.dist.pp_size  # Use pp_size as tag to avoid conflicts

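The replacement is behavior-preserving: tp_cp_broadcast applies the same tp_size > 1 and cp_size > 1 guards internally, so plain-TP runs still perform a single TP broadcast while helix (TP+CP) runs get both stages, with the previously open-coded sequence now centralized in the communicator.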
tests/unittest/_torch/distributed/test_cp_broadcast.py

Lines changed: 55 additions & 0 deletions
@@ -208,3 +208,58 @@ def test_mpi_cp_broadcast_integration():
 if __name__ == "__main__":
     # Allow running directly with mpirun
     pytest.main([__file__, "-v"])
+
+
+class TestMPIDistTpCpBroadcast:
+    """Tests for MPIDist.tp_cp_broadcast functionality."""
+
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        """Set up MPI environment and mapping for each test."""
+        skip_if_not_mpi()
+        self.rank, self.world_size = get_mpi_info()
+
+        # Set up mapping with both TP and CP enabled
+        # For 2 ranks: tp_size=1, cp_size=2 (tp_cp_broadcast will only do cp_broadcast)
+        self.mapping = Mapping(
+            world_size=self.world_size,
+            rank=self.rank,
+            tp_size=1,
+            cp_size=self.world_size,
+            pp_size=1,
+        )
+        self.dist = MPIDist(mapping=self.mapping)
+
+    def test_tp_cp_broadcast_python_dict(self):
+        """Test broadcasting a Python dictionary via tp_cp_broadcast."""
+        root = 0
+
+        # Only rank 0 in both TP and CP groups should have the object
+        if self.mapping.tp_rank == root and self.mapping.cp_rank == root:
+            obj = {
+                "model_name": "llama",
+                "batch_size": 32,
+                "tokens": [1, 2, 3, 4, 5],
+            }
+        else:
+            obj = None
+
+        result = self.dist.tp_cp_broadcast(obj, root=root)
+
+        # Verify all ranks received the correct object
+        assert result["model_name"] == "llama"
+        assert result["batch_size"] == 32
+        assert result["tokens"] == [1, 2, 3, 4, 5]
+
+    def test_tp_cp_broadcast_python_list(self):
+        """Test broadcasting a Python list via tp_cp_broadcast."""
+        root = 0
+
+        if self.mapping.tp_rank == root and self.mapping.cp_rank == root:
+            obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
+        else:
+            obj = None
+
+        result = self.dist.tp_cp_broadcast(obj, root=root)
+
+        assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]

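Since the fixture pins tp_size=1 and cp_size=world_size, these tests exercise only the CP leg of tp_cp_broadcast; the TP-then-CP chaining itself is reached only through the helper's internal size guards. As the __main__ hook notes, the module can also be launched under MPI directly, e.g. mpirun -n 2 python tests/unittest/_torch/distributed/test_cp_broadcast.py (exact launcher flags depend on the local MPI setup).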
tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py

Lines changed: 64 additions & 0 deletions
@@ -337,6 +337,21 @@ def run_object_broadcast(self, root_obj, root: int = 0):
         # After broadcast, all CP ranks should have the same object
         return result == root_obj
 
+    def run_tp_cp_broadcast(self, root_obj, root: int = 0):
+        """Test broadcasting an object via tp_cp_broadcast."""
+        # For tp_cp_broadcast, only rank 0 in both TP and CP should have the object
+        tp_rank = self.mapping.tp_rank
+        cp_rank = self.mapping.cp_rank
+        if tp_rank == root and cp_rank == root:
+            obj = root_obj
+        else:
+            obj = None
+
+        result = self.dist.tp_cp_broadcast(obj, root=root)
+
+        # After broadcast, all TP and CP ranks should have the same object
+        return result == root_obj
+
 
 @pytest.mark.gpu2
 @pytest.mark.parametrize("hidden_size", [128, 512], ids=lambda x: f"hidden:{x}")
 
@@ -422,3 +437,52 @@ def test_cp_broadcast_object(setup_ray_cluster, test_object):
     ])
     for r in results:
         assert r is True, f"Object broadcast from root=0 failed for {type(test_object)}"
+
+
+@pytest.mark.gpu2
+@pytest.mark.parametrize("test_object", [
+    {
+        "key1": "value1",
+        "key2": [1, 2, 3]
+    },
+    ["item1", "item2", {
+        "nested": True
+    }],
+    "simple_string",
+],
+                         ids=["dict", "list", "string"])
+def test_tp_cp_broadcast(setup_ray_cluster, test_object):
+    """Test TorchDist.tp_cp_broadcast with various objects.
+
+    This tests the combined TP+CP broadcast which is used when both tensor
+    and context parallelism are enabled (e.g., helix parallelism).
+    """
+    world_size = 2
+    tp_size = 1
+    cp_size = 2  # Enable context parallelism (tp_cp_broadcast will only do cp_broadcast)
+
+    runtime_env = ray.runtime_env.RuntimeEnv()
+    runtime_env["env_vars"] = os.environ.copy()
+    runtime_env["env_vars"].update({
+        "TLLM_DISABLE_MPI": "1",
+        "MASTER_ADDR": "127.0.0.1",
+    })
+
+    remote_tests = []
+    for rank in range(world_size):
+        remote_tests.append(
+            CpBroadcastTest.options(runtime_env=runtime_env).remote(
+                rank, world_size, tp_size, cp_size))
+
+    ray.get([test.__ray_ready__.remote() for test in remote_tests])
+
+    port = ray.get(remote_tests[0].setup_tcp_store.remote())
+    ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])
+
+    # Test tp_cp_broadcast from root=0
+    results = ray.get([
+        test.run_tp_cp_broadcast.remote(test_object, root=0)
+        for test in remote_tests
+    ])
+    for r in results:
+        assert r is True, f"tp_cp_broadcast from root=0 failed for {type(test_object)}"

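The Ray-orchestrated test mirrors the MPI one: TLLM_DISABLE_MPI=1 keeps the actors on the TorchDist path, and with tp_size=1, cp_size=2 on two GPUs only the CP leg of the helper runs. Covering the chained TP-then-CP case would take a larger grid (e.g., tp_size=2, cp_size=2 on four GPUs).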