|
| 1 | +""" |
| 2 | +Tests for cp_broadcast functionality in both MPIDist and TorchDist. |
| 3 | +
|
| 4 | +This module tests the context parallelism broadcast operation which is used |
| 5 | +when CP (context parallelism) is enabled (e.g., in Helix parallelism). |
| 6 | +
|
| 7 | +For MPIDist tests, run with mpirun: |
| 8 | + mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v |
| 9 | +
|
| 10 | +For TorchDist tests, see test_ops.py which uses Ray for distributed testing. |
| 11 | +""" |
| 12 | + |
| 13 | +import numpy as np |
| 14 | +import pytest |
| 15 | +import torch |
| 16 | + |
| 17 | +from tensorrt_llm._torch.distributed import MPIDist |
| 18 | +from tensorrt_llm.mapping import Mapping |
| 19 | + |
| 20 | + |
| 21 | +def get_mpi_info(): |
| 22 | + """Get MPI rank and world size, returns (0, 1) if MPI is not available.""" |
| 23 | + try: |
| 24 | + from mpi4py import MPI |
| 25 | + comm = MPI.COMM_WORLD |
| 26 | + return comm.Get_rank(), comm.Get_size() |
| 27 | + except ImportError: |
| 28 | + return 0, 1 |
| 29 | + |
| 30 | + |
| 31 | +def skip_if_not_mpi(): |
| 32 | + """Skip test if not running under MPI with sufficient ranks.""" |
| 33 | + rank, world_size = get_mpi_info() |
| 34 | + if world_size < 2: |
| 35 | + pytest.skip("Test requires at least 2 MPI ranks (run with mpirun -n 2)") |
| 36 | + |
| 37 | + |
| 38 | +class TestMPIDistCpBroadcast: |
| 39 | + """Tests for MPIDist.cp_broadcast functionality.""" |
| 40 | + |
| 41 | + @pytest.fixture(autouse=True) |
| 42 | + def setup(self): |
| 43 | + """Set up MPI environment and mapping for each test.""" |
| 44 | + skip_if_not_mpi() |
| 45 | + self.rank, self.world_size = get_mpi_info() |
| 46 | + |
| 47 | + # Set up mapping with CP enabled (cp_size = world_size, tp_size = 1) |
| 48 | + self.mapping = Mapping( |
| 49 | + world_size=self.world_size, |
| 50 | + rank=self.rank, |
| 51 | + tp_size=1, |
| 52 | + cp_size=self.world_size, |
| 53 | + pp_size=1, |
| 54 | + ) |
| 55 | + self.dist = MPIDist(mapping=self.mapping) |
| 56 | + |
| 57 | + def test_broadcast_numpy_array(self): |
| 58 | + """Test broadcasting a numpy array via cp_broadcast.""" |
| 59 | + root = 0 |
| 60 | + shape = (64, 128) |
| 61 | + |
| 62 | + if self.mapping.cp_rank == root: |
| 63 | + # Root rank creates the data to broadcast |
| 64 | + data = np.random.randn(*shape).astype(np.float32) |
| 65 | + else: |
| 66 | + # Non-root ranks have empty/zero data |
| 67 | + data = np.zeros(shape, dtype=np.float32) |
| 68 | + |
| 69 | + # Store original data from root for verification |
| 70 | + from mpi4py import MPI |
| 71 | + expected = np.zeros(shape, dtype=np.float32) |
| 72 | + MPI.COMM_WORLD.Bcast(data if self.mapping.cp_rank == root else expected, |
| 73 | + root=root) |
| 74 | + if self.mapping.cp_rank == root: |
| 75 | + expected = data.copy() |
| 76 | + |
| 77 | + # Perform cp_broadcast |
| 78 | + result = self.dist.cp_broadcast(data, root=root) |
| 79 | + |
| 80 | + # Verify all ranks have the same data |
| 81 | + np.testing.assert_array_almost_equal(result, expected) |
| 82 | + |
| 83 | + def test_broadcast_python_dict(self): |
| 84 | + """Test broadcasting a Python dictionary via cp_broadcast.""" |
| 85 | + root = 0 |
| 86 | + |
| 87 | + if self.mapping.cp_rank == root: |
| 88 | + obj = { |
| 89 | + "model_name": "llama", |
| 90 | + "batch_size": 32, |
| 91 | + "tokens": [1, 2, 3, 4, 5], |
| 92 | + "config": {"hidden_size": 4096, "num_layers": 32} |
| 93 | + } |
| 94 | + else: |
| 95 | + obj = None |
| 96 | + |
| 97 | + result = self.dist.cp_broadcast(obj, root=root) |
| 98 | + |
| 99 | + # Verify all ranks received the correct object |
| 100 | + assert result["model_name"] == "llama" |
| 101 | + assert result["batch_size"] == 32 |
| 102 | + assert result["tokens"] == [1, 2, 3, 4, 5] |
| 103 | + assert result["config"]["hidden_size"] == 4096 |
| 104 | + assert result["config"]["num_layers"] == 32 |
| 105 | + |
| 106 | + def test_broadcast_python_list(self): |
| 107 | + """Test broadcasting a Python list via cp_broadcast.""" |
| 108 | + root = 0 |
| 109 | + |
| 110 | + if self.mapping.cp_rank == root: |
| 111 | + obj = ["request1", "request2", {"id": 123, "data": [1, 2, 3]}] |
| 112 | + else: |
| 113 | + obj = None |
| 114 | + |
| 115 | + result = self.dist.cp_broadcast(obj, root=root) |
| 116 | + |
| 117 | + assert result == ["request1", "request2", {"id": 123, "data": [1, 2, 3]}] |
| 118 | + |
| 119 | + def test_broadcast_from_non_zero_root(self): |
| 120 | + """Test broadcasting from a non-zero root rank.""" |
| 121 | + if self.world_size < 2: |
| 122 | + pytest.skip("Need at least 2 ranks to test non-zero root") |
| 123 | + |
| 124 | + root = 1 # Broadcast from rank 1 |
| 125 | + |
| 126 | + if self.mapping.cp_rank == root: |
| 127 | + obj = {"source": "rank1", "value": 42} |
| 128 | + else: |
| 129 | + obj = None |
| 130 | + |
| 131 | + result = self.dist.cp_broadcast(obj, root=root) |
| 132 | + |
| 133 | + assert result["source"] == "rank1" |
| 134 | + assert result["value"] == 42 |
| 135 | + |
| 136 | + def test_broadcast_large_object(self): |
| 137 | + """Test broadcasting a large object that may require chunking.""" |
| 138 | + root = 0 |
| 139 | + # Create a large list to test chunking behavior |
| 140 | + large_size = 100000 |
| 141 | + |
| 142 | + if self.mapping.cp_rank == root: |
| 143 | + obj = list(range(large_size)) |
| 144 | + else: |
| 145 | + obj = None |
| 146 | + |
| 147 | + result = self.dist.cp_broadcast(obj, root=root) |
| 148 | + |
| 149 | + assert len(result) == large_size |
| 150 | + assert result[0] == 0 |
| 151 | + assert result[-1] == large_size - 1 |
| 152 | + |
| 153 | + def test_broadcast_string(self): |
| 154 | + """Test broadcasting a simple string via cp_broadcast.""" |
| 155 | + root = 0 |
| 156 | + |
| 157 | + if self.mapping.cp_rank == root: |
| 158 | + obj = "Hello from root rank!" |
| 159 | + else: |
| 160 | + obj = None |
| 161 | + |
| 162 | + result = self.dist.cp_broadcast(obj, root=root) |
| 163 | + |
| 164 | + assert result == "Hello from root rank!" |
| 165 | + |
| 166 | + |
| 167 | +# Additional integration-style test that can be run standalone |
| 168 | +def test_mpi_cp_broadcast_integration(): |
| 169 | + """ |
| 170 | + Integration test for MPIDist cp_broadcast. |
| 171 | +
|
| 172 | + Run with: mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py::test_mpi_cp_broadcast_integration -v |
| 173 | + """ |
| 174 | + rank, world_size = get_mpi_info() |
| 175 | + if world_size < 2: |
| 176 | + pytest.skip("Test requires at least 2 MPI ranks") |
| 177 | + |
| 178 | + # Create mapping with CP enabled |
| 179 | + mapping = Mapping( |
| 180 | + world_size=world_size, |
| 181 | + rank=rank, |
| 182 | + tp_size=1, |
| 183 | + cp_size=world_size, |
| 184 | + pp_size=1, |
| 185 | + ) |
| 186 | + dist = MPIDist(mapping=mapping) |
| 187 | + |
| 188 | + # Test 1: Broadcast dict |
| 189 | + if mapping.cp_rank == 0: |
| 190 | + payload = {"requests": [{"id": i} for i in range(10)]} |
| 191 | + else: |
| 192 | + payload = None |
| 193 | + |
| 194 | + result = dist.cp_broadcast(payload, root=0) |
| 195 | + assert len(result["requests"]) == 10 |
| 196 | + assert result["requests"][0]["id"] == 0 |
| 197 | + |
| 198 | + # Test 2: Broadcast numpy array |
| 199 | + shape = (32, 64) |
| 200 | + if mapping.cp_rank == 0: |
| 201 | + arr = np.ones(shape, dtype=np.float32) * (rank + 1) |
| 202 | + else: |
| 203 | + arr = np.zeros(shape, dtype=np.float32) |
| 204 | + |
| 205 | + result = dist.cp_broadcast(arr, root=0) |
| 206 | + expected_val = 1.0 # From rank 0 |
| 207 | + np.testing.assert_array_almost_equal(result, np.ones(shape) * expected_val) |
| 208 | + |
| 209 | + |
| 210 | +if __name__ == "__main__": |
| 211 | + # Allow running directly with mpirun |
| 212 | + pytest.main([__file__, "-v"]) |
| 213 | + |
0 commit comments