
Commit 4b78b48

remove duplication in test
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent f16ee9b commit 4b78b48

File tree

  • tests/unittest/_torch/ray_orchestrator/multi_gpu

1 file changed: +22 -61 lines changed

tests/unittest/_torch/ray_orchestrator/multi_gpu/test_ops.py

Lines changed: 22 additions & 61 deletions
@@ -312,15 +312,15 @@ def run_tensor_broadcast(self, root_tensor: torch.Tensor, root: int = 0):
         """Test broadcasting a tensor via cp_broadcast."""
         cp_rank = self.mapping.cp_rank
         if cp_rank == root:
-            # Root rank has the tensor to broadcast
+            # Root rank has the tensor to broadcast.
             tensor = root_tensor.cuda()
         else:
-            # Non-root ranks start with zeros
+            # Non-root ranks start with zeros.
             tensor = torch.zeros_like(root_tensor).cuda()
 
         result = self.dist.cp_broadcast(tensor, root=root)
 
-        # After broadcast, all CP ranks should have the same tensor
+        # After broadcast, all CP ranks should have the same tensor.
         expected = root_tensor.cuda()
         return torch.allclose(result, expected)
 
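
For reference, a minimal sketch of the broadcast semantics this test checks, written against plain torch.distributed rather than TensorRT-LLM's TorchDist wrapper; the cp_group argument is a hypothetical stand-in for however the CP process group is actually constructed:

import torch
import torch.distributed as dist

def broadcast_within_cp(tensor: torch.Tensor, root: int, cp_group) -> torch.Tensor:
    # dist.broadcast copies the tensor from the root rank to every rank in
    # the group, in place; non-root ranks must supply a correctly shaped
    # buffer, which is why the test pre-fills them with torch.zeros_like.
    dist.broadcast(tensor, src=root, group=cp_group)
    return tensor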

@@ -334,12 +334,12 @@ def run_object_broadcast(self, root_obj, root: int = 0):
 
         result = self.dist.cp_broadcast(obj, root=root)
 
-        # After broadcast, all CP ranks should have the same object
+        # After broadcast, all CP ranks should have the same object.
         return result == root_obj
 
     def run_tp_cp_broadcast(self, root_obj, root: int = 0):
         """Test broadcasting an object via tp_cp_broadcast."""
-        # For tp_cp_broadcast, only rank 0 in both TP and CP should have the object
+        # For tp_cp_broadcast, only rank 0 in both TP and CP should have the object.
         tp_rank = self.mapping.tp_rank
         cp_rank = self.mapping.cp_rank
         if tp_rank == root and cp_rank == root:
@@ -349,7 +349,7 @@ def run_tp_cp_broadcast(self, root_obj, root: int = 0):
 
         result = self.dist.tp_cp_broadcast(obj, root=root)
 
-        # After broadcast, all TP and CP ranks should have the same object
+        # After broadcast, all TP and CP ranks should have the same object.
         return result == root_obj
 
 
@@ -362,9 +362,9 @@ def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
     dtype = torch.bfloat16
     world_size = 2
     tp_size = 1
-    cp_size = 2  # Enable context parallelism
+    cp_size = 2  # Enable context parallelism.
 
-    # Create tensor to broadcast from root
+    # Create tensor to broadcast from root.
     root_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
 
     runtime_env = ray.runtime_env.RuntimeEnv()
@@ -385,7 +385,7 @@ def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
     port = ray.get(remote_tests[0].setup_tcp_store.remote())
     ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])
 
-    # Test broadcasting from root=0
+    # Test broadcasting from root=0.
     results = ray.get([
         test.run_tensor_broadcast.remote(root_tensor, root=0)
         for test in remote_tests
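
The scaffolding above (one Ray actor per rank, a readiness barrier, then fanning the test method out and gathering results) is a generic Ray pattern. A minimal self-contained sketch, with a toy EchoActor standing in for CpBroadcastTest:

import ray

@ray.remote
class EchoActor:
    def __init__(self, rank: int):
        self.rank = rank

    def run(self, payload):
        return (self.rank, payload)

ray.init()
actors = [EchoActor.remote(rank) for rank in range(2)]
# ray.get on a list blocks until every remote call has completed.
results = ray.get([actor.run.remote("hello") for actor in actors])
assert results == [(0, "hello"), (1, "hello")]
ray.shutdown()
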
@@ -406,60 +406,21 @@ def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
     "simple_string",
 ],
                          ids=["dict", "list", "string"])
-def test_cp_broadcast_object(setup_ray_cluster, test_object):
-    """Test TorchDist.cp_broadcast with non-tensor objects."""
-    world_size = 2
-    tp_size = 1
-    cp_size = 2  # Enable context parallelism
-
-    runtime_env = ray.runtime_env.RuntimeEnv()
-    runtime_env["env_vars"] = os.environ.copy()
-    runtime_env["env_vars"].update({
-        "TLLM_DISABLE_MPI": "1",
-        "MASTER_ADDR": "127.0.0.1",
-    })
-
-    remote_tests = []
-    for rank in range(world_size):
-        remote_tests.append(
-            CpBroadcastTest.options(runtime_env=runtime_env).remote(
-                rank, world_size, tp_size, cp_size))
-
-    ray.get([test.__ray_ready__.remote() for test in remote_tests])
-
-    port = ray.get(remote_tests[0].setup_tcp_store.remote())
-    ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])
-
-    # Test broadcasting object from root=0
-    results = ray.get([
-        test.run_object_broadcast.remote(test_object, root=0)
-        for test in remote_tests
-    ])
-    for r in results:
-        assert r is True, f"Object broadcast from root=0 failed for {type(test_object)}"
-
-
-@pytest.mark.gpu2
-@pytest.mark.parametrize("test_object", [
-    {
-        "key1": "value1",
-        "key2": [1, 2, 3]
-    },
-    ["item1", "item2", {
-        "nested": True
-    }],
-    "simple_string",
+@pytest.mark.parametrize("broadcast_method", [
+    "run_object_broadcast",
+    "run_tp_cp_broadcast",
 ],
-                         ids=["dict", "list", "string"])
-def test_tp_cp_broadcast(setup_ray_cluster, test_object):
-    """Test TorchDist.tp_cp_broadcast with various objects.
+                         ids=["cp_broadcast", "tp_cp_broadcast"])
+def test_cp_tp_broadcast_object(setup_ray_cluster, test_object,
+                                broadcast_method):
+    """Test TorchDist.cp_broadcast and tp_cp_broadcast with non-tensor objects.
 
-    This tests the combined TP+CP broadcast which is used when both tensor
-    and context parallelism are enabled (e.g., helix parallelism).
+    This tests both cp_broadcast (for context parallelism only) and tp_cp_broadcast
+    (for combined TP+CP broadcast used in helix parallelism).
     """
     world_size = 2
     tp_size = 1
-    cp_size = 2  # Enable context parallelism (tp_cp_broadcast will only do cp_broadcast)
+    cp_size = 2  # Enable context parallelism.
 
     runtime_env = ray.runtime_env.RuntimeEnv()
     runtime_env["env_vars"] = os.environ.copy()
@@ -479,10 +440,10 @@ def test_tp_cp_broadcast(setup_ray_cluster, test_object):
     port = ray.get(remote_tests[0].setup_tcp_store.remote())
     ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])
 
-    # Test tp_cp_broadcast from root=0
+    # Test broadcasting object from root=0 using the specified method.
     results = ray.get([
-        test.run_tp_cp_broadcast.remote(test_object, root=0)
+        getattr(test, broadcast_method).remote(test_object, root=0)
         for test in remote_tests
     ])
     for r in results:
-        assert r is True, f"tp_cp_broadcast from root=0 failed for {type(test_object)}"
+        assert r is True, f"{broadcast_method} from root=0 failed for {type(test_object)}"
