
Commit 9ea59ed

formatting
1 parent cc1bb16 commit 9ea59ed

File tree: 2 files changed, +15 −13 lines changed


tests/unittest/_torch/distributed/test_cp_broadcast.py

Lines changed: 5 additions & 8 deletions
@@ -5,14 +5,13 @@
when CP (context parallelism) is enabled (e.g., in Helix parallelism).

For MPIDist tests, run with mpirun:
-mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v
+mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py -v

For TorchDist tests, see test_ops.py which uses Ray for distributed testing.
"""

import numpy as np
import pytest
-import torch

from tensorrt_llm._torch.distributed import MPIDist
from tensorrt_llm.mapping import Mapping
@@ -22,6 +21,7 @@ def get_mpi_info():
    """Get MPI rank and world size, returns (0, 1) if MPI is not available."""
    try:
        from mpi4py import MPI
+
        comm = MPI.COMM_WORLD
        return comm.Get_rank(), comm.Get_size()
    except ImportError:
@@ -68,9 +68,9 @@ def test_broadcast_numpy_array(self):

        # Store original data from root for verification
        from mpi4py import MPI
+
        expected = np.zeros(shape, dtype=np.float32)
-        MPI.COMM_WORLD.Bcast(data if self.mapping.cp_rank == root else expected,
-                             root=root)
+        MPI.COMM_WORLD.Bcast(data if self.mapping.cp_rank == root else expected, root=root)
        if self.mapping.cp_rank == root:
            expected = data.copy()

@@ -89,7 +89,7 @@ def test_broadcast_python_dict(self):
                "model_name": "llama",
                "batch_size": 32,
                "tokens": [1, 2, 3, 4, 5],
-                "config": {"hidden_size": 4096, "num_layers": 32}
+                "config": {"hidden_size": 4096, "num_layers": 32},
            }
        else:
            obj = None
@@ -168,8 +168,6 @@ def test_broadcast_string(self):
def test_mpi_cp_broadcast_integration():
    """
    Integration test for MPIDist cp_broadcast.
-
-    Run with: mpirun -n 2 python -m pytest tests/unittest/_torch/distributed/test_cp_broadcast.py::test_mpi_cp_broadcast_integration -v
    """
    rank, world_size = get_mpi_info()
    if world_size < 2:
@@ -210,4 +208,3 @@ def test_mpi_cp_broadcast_integration():
if __name__ == "__main__":
    # Allow running directly with mpirun
    pytest.main([__file__, "-v"])
-
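
For readers unfamiliar with the Bcast call that the hunk above collapses onto one line: the root rank passes its data buffer while every other rank passes a receive buffer, and mpi4py fills the receive buffers in place. A minimal standalone sketch of that pattern (a hypothetical example, not part of this commit; plain mpi4py and numpy, no TensorRT-LLM):

# sketch_mpi_bcast.py -- hypothetical example, not from this commit.
# Run with: mpirun -n 2 python sketch_mpi_bcast.py
import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
root = 0

# Buffer broadcast (upper-case Bcast): the root passes its data, every other
# rank passes a zeroed buffer of the same shape/dtype that gets filled in place.
data = np.arange(8, dtype=np.float32) if rank == root else np.zeros(8, dtype=np.float32)
comm.Bcast(data, root=root)
assert np.allclose(data, np.arange(8, dtype=np.float32))

# Object broadcast (lower-case bcast): pickles arbitrary Python objects, the
# mechanism behind broadcasting dicts, lists, and strings in the other tests.
obj = {"config": {"hidden_size": 4096}} if rank == root else None
obj = comm.bcast(obj, root=root)
assert obj["config"]["hidden_size"] == 4096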

tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py

Lines changed: 10 additions & 5 deletions
@@ -339,8 +339,7 @@ def run_object_broadcast(self, root_obj, root: int = 0):


@pytest.mark.gpu2
-@pytest.mark.parametrize("hidden_size", [128, 512],
-                         ids=lambda x: f"hidden:{x}")
+@pytest.mark.parametrize("hidden_size", [128, 512], ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 32], ids=lambda x: f"seqlen:{x}")
def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
    """Test TorchDist.cp_broadcast with tensor data."""
@@ -382,10 +381,16 @@ def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):

@pytest.mark.gpu2
@pytest.mark.parametrize("test_object", [
-    {"key1": "value1", "key2": [1, 2, 3]},
-    ["item1", "item2", {"nested": True}],
+    {
+        "key1": "value1",
+        "key2": [1, 2, 3]
+    },
+    ["item1", "item2", {
+        "nested": True
+    }],
    "simple_string",
-], ids=["dict", "list", "string"])
+],
+                         ids=["dict", "list", "string"])
def test_cp_broadcast_object(setup_ray_cluster, test_object):
    """Test TorchDist.cp_broadcast with non-tensor objects."""
    world_size = 2
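
As an aside on the decorators being reflowed above, pytest accepts ids either as a callable applied to each parameter value or as an explicit list, and stacked parametrize decorators run the cross product of their values. A small self-contained sketch of both styles (a hypothetical file, not part of this commit):

# sketch_parametrize_ids.py -- hypothetical example, not from this commit.
import pytest


# ids as a callable: applied to each value to build the test ID, e.g. "hidden:128".
@pytest.mark.parametrize("hidden_size", [128, 512], ids=lambda x: f"hidden:{x}")
@pytest.mark.parametrize("seq_len", [16, 32], ids=lambda x: f"seqlen:{x}")
def test_shapes(seq_len, hidden_size):
    # Stacked decorators yield 2 x 2 = 4 test cases.
    assert seq_len > 0 and hidden_size > 0


# ids as an explicit list: one label per parameter value, in order.
@pytest.mark.parametrize("test_object",
                         [{"key": 1}, ["a", "b"], "simple_string"],
                         ids=["dict", "list", "string"])
def test_objects(test_object):
    assert test_object is not None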
