55when CP (context parallelism) is enabled (e.g., in Helix parallelism).
66
77For MPIDist tests, run with mpirun:
8- mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v
8+ mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v
99
1010For TorchDist tests, see test_ops.py which uses Ray for distributed testing.
1111"""
1212
1313import numpy as np
1414import pytest
15- import torch
1615
1716from tensorrt_llm ._torch .distributed import MPIDist
1817from tensorrt_llm .mapping import Mapping
@@ -22,6 +21,7 @@ def get_mpi_info():
2221 """Get MPI rank and world size, returns (0, 1) if MPI is not available."""
2322 try :
2423 from mpi4py import MPI
24+
2525 comm = MPI .COMM_WORLD
2626 return comm .Get_rank (), comm .Get_size ()
2727 except ImportError :
@@ -68,9 +68,9 @@ def test_broadcast_numpy_array(self):
6868
6969 # Store original data from root for verification
7070 from mpi4py import MPI
71+
7172 expected = np .zeros (shape , dtype = np .float32 )
72- MPI .COMM_WORLD .Bcast (data if self .mapping .cp_rank == root else expected ,
73- root = root )
73+ MPI .COMM_WORLD .Bcast (data if self .mapping .cp_rank == root else expected , root = root )
7474 if self .mapping .cp_rank == root :
7575 expected = data .copy ()
7676
@@ -89,7 +89,7 @@ def test_broadcast_python_dict(self):
8989 "model_name" : "llama" ,
9090 "batch_size" : 32 ,
9191 "tokens" : [1 , 2 , 3 , 4 , 5 ],
92- "config" : {"hidden_size" : 4096 , "num_layers" : 32 }
92+ "config" : {"hidden_size" : 4096 , "num_layers" : 32 },
9393 }
9494 else :
9595 obj = None
@@ -168,8 +168,6 @@ def test_broadcast_string(self):
168168def test_mpi_cp_broadcast_integration ():
169169 """
170170 Integration test for MPIDist cp_broadcast.
171-
172- Run with: mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py::test_mpi_cp_broadcast_integration -v
173171 """
174172 rank , world_size = get_mpi_info ()
175173 if world_size < 2 :
@@ -210,4 +208,3 @@ def test_mpi_cp_broadcast_integration():
210208if __name__ == "__main__" :
211209 # Allow running directly with mpirun
212210 pytest .main ([__file__ , "-v" ])
213-
0 commit comments