Commit 658663e

add unit tests
1 parent db62a2b commit 658663e

File tree

2 files changed (+372 -0 lines):
  tests/unittest/_torch/distributed/test_cp_broadcast.py  (+213)
  tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py  (+159)

tests/unittest/_torch/distributed/test_cp_broadcast.py

Lines changed: 213 additions & 0 deletions
@@ -0,0 +1,213 @@
"""
Tests for cp_broadcast functionality in both MPIDist and TorchDist.

This module tests the context parallelism broadcast operation which is used
when CP (context parallelism) is enabled (e.g., in Helix parallelism).

For MPIDist tests, run with mpirun:
    mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v

For TorchDist tests, see test_ops.py which uses Ray for distributed testing.
"""

import numpy as np
import pytest
import torch

from tensorrt_llm._torch.distributed import MPIDist
from tensorrt_llm.mapping import Mapping


def get_mpi_info():
    """Get MPI rank and world size; returns (0, 1) if MPI is not available."""
    try:
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        return comm.Get_rank(), comm.Get_size()
    except ImportError:
        return 0, 1


def skip_if_not_mpi():
    """Skip the test if not running under MPI with sufficient ranks."""
    rank, world_size = get_mpi_info()
    if world_size < 2:
        pytest.skip("Test requires at least 2 MPI ranks (run with mpirun -n 2)")


class TestMPIDistCpBroadcast:
    """Tests for MPIDist.cp_broadcast functionality."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Set up MPI environment and mapping for each test."""
        skip_if_not_mpi()
        self.rank, self.world_size = get_mpi_info()

        # Set up mapping with CP enabled (cp_size = world_size, tp_size = 1)
        self.mapping = Mapping(
            world_size=self.world_size,
            rank=self.rank,
            tp_size=1,
            cp_size=self.world_size,
            pp_size=1,
        )
        self.dist = MPIDist(mapping=self.mapping)

    def test_broadcast_numpy_array(self):
        """Test broadcasting a numpy array via cp_broadcast."""
        root = 0
        shape = (64, 128)

        if self.mapping.cp_rank == root:
            # Root rank creates the data to broadcast
            data = np.random.randn(*shape).astype(np.float32)
        else:
            # Non-root ranks have empty/zero data
            data = np.zeros(shape, dtype=np.float32)

        # Store original data from root for verification
        from mpi4py import MPI
        expected = np.zeros(shape, dtype=np.float32)
        MPI.COMM_WORLD.Bcast(data if self.mapping.cp_rank == root else expected,
                             root=root)
        if self.mapping.cp_rank == root:
            expected = data.copy()

        # Perform cp_broadcast
        result = self.dist.cp_broadcast(data, root=root)

        # Verify all ranks have the same data
        np.testing.assert_array_almost_equal(result, expected)

    def test_broadcast_python_dict(self):
        """Test broadcasting a Python dictionary via cp_broadcast."""
        root = 0

        if self.mapping.cp_rank == root:
            obj = {
                "model_name": "llama",
                "batch_size": 32,
                "tokens": [1, 2, 3, 4, 5],
                "config": {"hidden_size": 4096, "num_layers": 32}
            }
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        # Verify all ranks received the correct object
        assert result["model_name"] == "llama"
        assert result["batch_size"] == 32
        assert result["tokens"] == [1, 2, 3, 4, 5]
        assert result["config"]["hidden_size"] == 4096
        assert result["config"]["num_layers"] == 32

    def test_broadcast_python_list(self):
        """Test broadcasting a Python list via cp_broadcast."""
        root = 0

        if self.mapping.cp_rank == root:
            obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}]

    def test_broadcast_from_non_zero_root(self):
        """Test broadcasting from a non-zero root rank."""
        if self.world_size < 2:
            pytest.skip("Need at least 2 ranks to test non-zero root")

        root = 1  # Broadcast from rank 1

        if self.mapping.cp_rank == root:
            obj = {"source": "rank1", "value": 42}
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        assert result["source"] == "rank1"
        assert result["value"] == 42

    def test_broadcast_large_object(self):
        """Test broadcasting a large object that may require chunking."""
        root = 0
        # Create a large list to test chunking behavior
        large_size = 100000

        if self.mapping.cp_rank == root:
            obj = list(range(large_size))
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        assert len(result) == large_size
        assert result[0] == 0
        assert result[-1] == large_size - 1

    def test_broadcast_string(self):
        """Test broadcasting a simple string via cp_broadcast."""
        root = 0

        if self.mapping.cp_rank == root:
            obj = "Hello from root rank!"
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        assert result == "Hello from root rank!"


# Additional integration-style test that can be run standalone
def test_mpi_cp_broadcast_integration():
    """
    Integration test for MPIDist cp_broadcast.

    Run with: mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py::test_mpi_cp_broadcast_integration -v
    """
    rank, world_size = get_mpi_info()
    if world_size < 2:
        pytest.skip("Test requires at least 2 MPI ranks")

    # Create mapping with CP enabled
    mapping = Mapping(
        world_size=world_size,
        rank=rank,
        tp_size=1,
        cp_size=world_size,
        pp_size=1,
    )
    dist = MPIDist(mapping=mapping)

    # Test 1: Broadcast dict
    if mapping.cp_rank == 0:
        payload = {"requests": [{"id": i} for i in range(10)]}
    else:
        payload = None

    result = dist.cp_broadcast(payload, root=0)
    assert len(result["requests"]) == 10
    assert result["requests"][0]["id"] == 0

    # Test 2: Broadcast numpy array
    shape = (32, 64)
    if mapping.cp_rank == 0:
        arr = np.ones(shape, dtype=np.float32) * (rank + 1)
    else:
        arr = np.zeros(shape, dtype=np.float32)

    result = dist.cp_broadcast(arr, root=0)
    expected_val = 1.0  # From rank 0
    np.testing.assert_array_almost_equal(result, np.ones(shape) * expected_val)


if __name__ == "__main__":
    # Allow running directly with mpirun
    pytest.main([__file__, "-v"])

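Reviewer note: every MPIDist test above follows the same pattern: build a Mapping whose cp_size equals the world size, wrap it in MPIDist, and let cp_broadcast replicate the CP root's payload to all CP ranks. A minimal standalone sketch of that pattern, using only the Mapping/MPIDist calls exercised in the tests (the file name is hypothetical; run with mpirun -n 2 python cp_broadcast_sketch.py):

    import numpy as np
    from mpi4py import MPI

    from tensorrt_llm._torch.distributed import MPIDist
    from tensorrt_llm.mapping import Mapping

    comm = MPI.COMM_WORLD
    rank, world_size = comm.Get_rank(), comm.Get_size()

    # CP spans the whole job: cp_size == world_size, tp_size == pp_size == 1.
    mapping = Mapping(world_size=world_size, rank=rank, tp_size=1,
                      cp_size=world_size, pp_size=1)
    dist = MPIDist(mapping=mapping)

    # Only the CP root supplies a payload; every CP rank gets the root's copy back.
    payload = {"tokens": [1, 2, 3]} if mapping.cp_rank == 0 else None
    payload = dist.cp_broadcast(payload, root=0)
    assert payload == {"tokens": [1, 2, 3]}

    # Numpy arrays go through the same call.
    arr = (np.ones((4, 4), np.float32) if mapping.cp_rank == 0
           else np.zeros((4, 4), np.float32))
    arr = dist.cp_broadcast(arr, root=0)
    assert arr.sum() == 16
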
tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py

Lines changed: 159 additions & 0 deletions
@@ -258,3 +258,162 @@ def test_allreduce_pg_op(setup_ray_cluster, seq_len, hidden_size):
    ])
    for r in results:
        assert r is True


@ray.remote(num_gpus=1)
class CpBroadcastTest:
    """Test worker for cp_broadcast operations with context parallelism."""

    def __init__(self, rank, world_size, tp_size, cp_size):
        self.rank = rank
        self.world_size = world_size
        self.tp_size = tp_size
        self.cp_size = cp_size
        self.master_address = os.environ["MASTER_ADDR"]

        assert len(ray.get_gpu_ids()) == 1
        self.gpu = int(ray.get_gpu_ids()[0])
        from tensorrt_llm.executor.ray_gpu_worker import RayWorkerWrapper
        local_gpu = RayWorkerWrapper.physical_to_local_id(self.gpu)
        torch.cuda.set_device(local_gpu)

    def _create_tcp_store(self,
                          port: Optional[int] = None
                          ) -> torch.distributed.TCPStore:
        actual_port = port if port is not None else 0
        return torch.distributed.TCPStore(host_name=self.master_address,
                                          port=actual_port,
                                          world_size=self.world_size,
                                          is_master=(self.rank == 0),
                                          wait_for_workers=False)

    def setup_tcp_store(self):
        if self.rank != 0:
            raise RuntimeError("Only the master worker can setup TCP store")
        self.store = self._create_tcp_store()
        return self.store.port

    def setup_distributed_env(self, port: int):
        if self.rank != 0:
            self.store = self._create_tcp_store(port)

        torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo",
                                             store=self.store,
                                             world_size=self.world_size,
                                             rank=self.rank)
        self.mapping = Mapping(world_size=self.world_size,
                               gpus_per_node=self.world_size,
                               tp_size=self.tp_size,
                               cp_size=self.cp_size,
                               rank=self.rank)
        self.dist = TorchDist(self.mapping)

    def run_tensor_broadcast(self, root_tensor: torch.Tensor, root: int = 0):
        """Test broadcasting a tensor via cp_broadcast."""
        cp_rank = self.mapping.cp_rank
        if cp_rank == root:
            # Root rank has the tensor to broadcast
            tensor = root_tensor.cuda()
        else:
            # Non-root ranks start with zeros
            tensor = torch.zeros_like(root_tensor).cuda()

        result = self.dist.cp_broadcast(tensor, root=root)

        # After broadcast, all CP ranks should have the same tensor
        expected = root_tensor.cuda()
        return torch.allclose(result, expected)

    def run_object_broadcast(self, root_obj, root: int = 0):
        """Test broadcasting a non-tensor object via cp_broadcast."""
        cp_rank = self.mapping.cp_rank
        if cp_rank == root:
            obj = root_obj
        else:
            obj = None

        result = self.dist.cp_broadcast(obj, root=root)

        # After broadcast, all CP ranks should have the same object
        return result == root_obj


@pytest.mark.gpu2
@pytest.mark.parametrize("hidden_size", [128, 512],
                         ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 32], ids=lambda x: f"seqlen:{x}")
def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
    """Test TorchDist.cp_broadcast with tensor data."""
    torch.manual_seed(42)
    dtype = torch.bfloat16
    world_size = 2
    tp_size = 1
    cp_size = 2  # Enable context parallelism

    # Create tensor to broadcast from root
    root_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)

    runtime_env = ray.runtime_env.RuntimeEnv()
    runtime_env["env_vars"] = os.environ.copy()
    runtime_env["env_vars"].update({
        "TLLM_DISABLE_MPI": "1",
        "MASTER_ADDR": "127.0.0.1",
    })

    remote_tests = []
    for rank in range(world_size):
        remote_tests.append(
            CpBroadcastTest.options(runtime_env=runtime_env).remote(
                rank, world_size, tp_size, cp_size))

    ray.get([test.__ray_ready__.remote() for test in remote_tests])

    port = ray.get(remote_tests[0].setup_tcp_store.remote())
    ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])

    # Test broadcasting from root=0
    results = ray.get([
        test.run_tensor_broadcast.remote(root_tensor, root=0)
        for test in remote_tests
    ])
    for r in results:
        assert r is True, "Tensor broadcast from root=0 failed"


@pytest.mark.gpu2
@pytest.mark.parametrize("test_object", [
    {"key1": "value1", "key2": [1, 2, 3]},
    ["item1", "item2", {"nested": True}],
    "simple_string",
], ids=["dict", "list", "string"])
def test_cp_broadcast_object(setup_ray_cluster, test_object):
    """Test TorchDist.cp_broadcast with non-tensor objects."""
    world_size = 2
    tp_size = 1
    cp_size = 2  # Enable context parallelism

    runtime_env = ray.runtime_env.RuntimeEnv()
    runtime_env["env_vars"] = os.environ.copy()
    runtime_env["env_vars"].update({
        "TLLM_DISABLE_MPI": "1",
        "MASTER_ADDR": "127.0.0.1",
    })

    remote_tests = []
    for rank in range(world_size):
        remote_tests.append(
            CpBroadcastTest.options(runtime_env=runtime_env).remote(
                rank, world_size, tp_size, cp_size))

    ray.get([test.__ray_ready__.remote() for test in remote_tests])

    port = ray.get(remote_tests[0].setup_tcp_store.remote())
    ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])

    # Test broadcasting object from root=0
    results = ray.get([
        test.run_object_broadcast.remote(test_object, root=0)
        for test in remote_tests
    ])
    for r in results:
        assert r is True, f"Object broadcast from root=0 failed for {type(test_object)}"

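Reviewer note: the Ray/TorchDist path above does the equivalent without MPI: each worker joins a "cuda:nccl,cpu:gloo" process group through a shared TCPStore, builds a Mapping with cp_size == world_size, and then calls TorchDist.cp_broadcast on either a CUDA tensor or a plain Python object. A minimal per-rank sketch of that flow outside Ray, under two assumptions not confirmed by this diff: that TorchDist is importable from tensorrt_llm._torch.distributed alongside MPIDist, and that the script is launched with torchrun --nproc_per_node=2 so RANK/WORLD_SIZE and the rendezvous variables are set:

    import os

    import torch

    from tensorrt_llm._torch.distributed import TorchDist  # assumed import path
    from tensorrt_llm.mapping import Mapping

    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(rank)

    # torchrun supplies MASTER_ADDR/MASTER_PORT, so the default env:// store suffices here.
    torch.distributed.init_process_group(backend="cuda:nccl,cpu:gloo",
                                         world_size=world_size, rank=rank)

    mapping = Mapping(world_size=world_size, gpus_per_node=world_size,
                      tp_size=1, cp_size=world_size, rank=rank)
    dist = TorchDist(mapping)

    # Tensor payload: non-root CP ranks start from zeros and receive the root's values.
    t = (torch.ones(8, device="cuda") if mapping.cp_rank == 0
         else torch.zeros(8, device="cuda"))
    t = dist.cp_broadcast(t, root=0)
    assert torch.allclose(t, torch.ones(8, device="cuda"))

    # Object payload: anything picklable is replicated the same way.
    meta = {"step": 1} if mapping.cp_rank == 0 else None
    meta = dist.cp_broadcast(meta, root=0)
    assert meta == {"step": 1}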