@@ -312,15 +312,15 @@ def run_tensor_broadcast(self, root_tensor: torch.Tensor, root: int = 0):
312312 """Test broadcasting a tensor via cp_broadcast."""
313313 cp_rank = self .mapping .cp_rank
314314 if cp_rank == root :
315- # Root rank has the tensor to broadcast
315+ # Root rank has the tensor to broadcast.
316316 tensor = root_tensor .cuda ()
317317 else :
318- # Non-root ranks start with zeros
318+ # Non-root ranks start with zeros.
319319 tensor = torch .zeros_like (root_tensor ).cuda ()
320320
321321 result = self .dist .cp_broadcast (tensor , root = root )
322322
323- # After broadcast, all CP ranks should have the same tensor
323+ # After broadcast, all CP ranks should have the same tensor.
324324 expected = root_tensor .cuda ()
325325 return torch .allclose (result , expected )
326326
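
For reference, a minimal sketch of what the tensor path of `cp_broadcast` plausibly reduces to, assuming it wraps `torch.distributed.broadcast` over a process group spanning only the CP ranks; the group handle and helper name here are illustrative assumptions, not TorchDist's actual internals:

```python
import torch
import torch.distributed as dist


def cp_broadcast_sketch(tensor: torch.Tensor, cp_group,
                        root_global_rank: int) -> torch.Tensor:
    # Broadcast in place from the CP root to every rank in `cp_group`.
    # Note: `src` is the *global* rank of the source, not its rank within
    # the group. `cp_group` is a hypothetical ProcessGroup covering only
    # the CP dimension.
    dist.broadcast(tensor, src=root_global_rank, group=cp_group)
    return tensor
```
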
@@ -334,12 +334,12 @@ def run_object_broadcast(self, root_obj, root: int = 0):
 
         result = self.dist.cp_broadcast(obj, root=root)
 
-        # After broadcast, all CP ranks should have the same object
+        # After broadcast, all CP ranks should have the same object.
         return result == root_obj
 
     def run_tp_cp_broadcast(self, root_obj, root: int = 0):
         """Test broadcasting an object via tp_cp_broadcast."""
-        # For tp_cp_broadcast, only rank 0 in both TP and CP should have the object
+        # For tp_cp_broadcast, only rank 0 in both TP and CP should have the object.
         tp_rank = self.mapping.tp_rank
         cp_rank = self.mapping.cp_rank
         if tp_rank == root and cp_rank == root:
@@ -349,7 +349,7 @@ def run_tp_cp_broadcast(self, root_obj, root: int = 0):
 
         result = self.dist.tp_cp_broadcast(obj, root=root)
 
-        # After broadcast, all TP and CP ranks should have the same object
+        # After broadcast, all TP and CP ranks should have the same object.
         return result == root_obj
 
 
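
The combined TP+CP broadcast can be pictured as two chained group broadcasts: every TP group first broadcasts from its `tp_rank == root` member (only the group containing the true root spreads the real payload at this point), then every CP group broadcasts from its `cp_rank == root` member, which received the payload in stage one. A hedged sketch under that assumption; the group handles and global source ranks are caller-supplied bookkeeping, not TorchDist's real signature:

```python
import torch.distributed as dist


def tp_cp_broadcast_sketch(obj, tp_group, cp_group, tp_src_global: int,
                           cp_src_global: int):
    buf = [obj]
    # Stage 1: fan out across this rank's TP group. Groups that do not
    # contain the true root propagate a stale value here.
    dist.broadcast_object_list(buf, src=tp_src_global, group=tp_group)
    # Stage 2: fan out across this rank's CP group from the member that
    # obtained the real payload in stage 1, overwriting any stale value.
    dist.broadcast_object_list(buf, src=cp_src_global, group=cp_group)
    return buf[0]
```
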
@@ -362,9 +362,9 @@ def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
     dtype = torch.bfloat16
     world_size = 2
     tp_size = 1
-    cp_size = 2  # Enable context parallelism
+    cp_size = 2  # Enable context parallelism.
 
-    # Create tensor to broadcast from root
+    # Create tensor to broadcast from root.
     root_tensor = torch.randn((seq_len, hidden_size), dtype=dtype)
 
     runtime_env = ray.runtime_env.RuntimeEnv()
@@ -385,7 +385,7 @@ def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
     port = ray.get(remote_tests[0].setup_tcp_store.remote())
     ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])
 
-    # Test broadcasting from root=0
+    # Test broadcasting from root=0.
     results = ray.get([
         test.run_tensor_broadcast.remote(root_tensor, root=0)
         for test in remote_tests
@@ -406,60 +406,21 @@ def test_cp_broadcast_tensor(setup_ray_cluster, seq_len, hidden_size):
406406 "simple_string" ,
407407],
408408 ids = ["dict" , "list" , "string" ])
409- def test_cp_broadcast_object (setup_ray_cluster , test_object ):
410- """Test TorchDist.cp_broadcast with non-tensor objects."""
411- world_size = 2
412- tp_size = 1
413- cp_size = 2 # Enable context parallelism
414-
415- runtime_env = ray .runtime_env .RuntimeEnv ()
416- runtime_env ["env_vars" ] = os .environ .copy ()
417- runtime_env ["env_vars" ].update ({
418- "TLLM_DISABLE_MPI" : "1" ,
419- "MASTER_ADDR" : "127.0.0.1" ,
420- })
421-
422- remote_tests = []
423- for rank in range (world_size ):
424- remote_tests .append (
425- CpBroadcastTest .options (runtime_env = runtime_env ).remote (
426- rank , world_size , tp_size , cp_size ))
427-
428- ray .get ([test .__ray_ready__ .remote () for test in remote_tests ])
429-
430- port = ray .get (remote_tests [0 ].setup_tcp_store .remote ())
431- ray .get ([test .setup_distributed_env .remote (port ) for test in remote_tests ])
432-
433- # Test broadcasting object from root=0
434- results = ray .get ([
435- test .run_object_broadcast .remote (test_object , root = 0 )
436- for test in remote_tests
437- ])
438- for r in results :
439- assert r is True , f"Object broadcast from root=0 failed for { type (test_object )} "
440-
441-
442- @pytest .mark .gpu2
443- @pytest .mark .parametrize ("test_object" , [
444- {
445- "key1" : "value1" ,
446- "key2" : [1 , 2 , 3 ]
447- },
448- ["item1" , "item2" , {
449- "nested" : True
450- }],
451- "simple_string" ,
409+ @pytest .mark .parametrize ("broadcast_method" , [
410+ "run_object_broadcast" ,
411+ "run_tp_cp_broadcast" ,
452412],
453- ids = ["dict" , "list" , "string" ])
454- def test_tp_cp_broadcast (setup_ray_cluster , test_object ):
455- """Test TorchDist.tp_cp_broadcast with various objects.
413+ ids = ["cp_broadcast" , "tp_cp_broadcast" ])
414+ def test_cp_tp_broadcast_object (setup_ray_cluster , test_object ,
415+ broadcast_method ):
416+ """Test TorchDist.cp_broadcast and tp_cp_broadcast with non-tensor objects.
456417
457- This tests the combined TP+CP broadcast which is used when both tensor
458- and context parallelism are enabled (e.g., helix parallelism).
418+ This tests both cp_broadcast (for context parallelism only) and tp_cp_broadcast
419+ (for combined TP+CP broadcast used in helix parallelism).
459420 """
460421 world_size = 2
461422 tp_size = 1
462- cp_size = 2 # Enable context parallelism (tp_cp_broadcast will only do cp_broadcast)
423+ cp_size = 2 # Enable context parallelism.
463424
464425 runtime_env = ray .runtime_env .RuntimeEnv ()
465426 runtime_env ["env_vars" ] = os .environ .copy ()
@@ -479,10 +440,10 @@ def test_tp_cp_broadcast(setup_ray_cluster, test_object):
     port = ray.get(remote_tests[0].setup_tcp_store.remote())
     ray.get([test.setup_distributed_env.remote(port) for test in remote_tests])
 
-    # Test tp_cp_broadcast from root=0
+    # Test broadcasting object from root=0 using the specified method.
     results = ray.get([
-        test.run_tp_cp_broadcast.remote(test_object, root=0)
+        getattr(test, broadcast_method).remote(test_object, root=0)
         for test in remote_tests
     ])
     for r in results:
-        assert r is True, f"tp_cp_broadcast from root=0 failed for {type(test_object)}"
+        assert r is True, f"{broadcast_method} from root=0 failed for {type(test_object)}"
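
For the non-tensor payloads exercised above (dicts, lists, strings), PyTorch's stock mechanism is `torch.distributed.broadcast_object_list`, which pickles the object on the source rank and unpickles it on every other rank in the group. A self-contained sketch of what the object path of `cp_broadcast` might look like; the helper name and group plumbing are assumptions:

```python
import torch.distributed as dist


def object_broadcast_sketch(obj, group, src_global_rank: int):
    # Non-source ranks pass a placeholder; after the collective, `buf[0]`
    # holds the source's (pickled-then-unpickled) object on every rank.
    buf = [obj if dist.get_rank() == src_global_rank else None]
    dist.broadcast_object_list(buf, src=src_global_rank, group=group)
    return buf[0]
```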