File tree Expand file tree Collapse file tree 2 files changed +67
-0
lines changed Expand file tree Collapse file tree 2 files changed +67
-0
lines changed Original file line number Diff line number Diff line change 2020from torch .distributed import ProcessGroup
2121from torchtnt .utils .distributed import (
2222 _validate_global_rank_world_size ,
23+ all_gather_str ,
2324 all_gather_tensors ,
2425 broadcast_str ,
2526 destroy_process_group ,
@@ -610,3 +611,29 @@ def _test_broadcast_str() -> None:
610611
611612 tc = unittest .TestCase ()
612613 tc .assertEqual (broadcasted_val , "foo" )
614+
615+ @skip_if_not_distributed
616+ def test_all_gather_str (self ) -> None :
617+ backend = "gloo"
618+ if torch .cuda .is_available ():
619+ backend = "nccl"
620+
621+ spawn_multi_process (2 , backend , self ._test_all_gather_str )
622+
623+ @staticmethod
624+ def _test_all_gather_str () -> None :
625+ if torch .cuda .is_available ():
626+ torch .cuda .set_device (dist .get_rank ())
627+
628+ val = None
629+ if dist .get_rank () == 0 :
630+ val = "foo"
631+ else :
632+ val = "barzoo"
633+
634+ # Test case 1: fixed_buffer_size == len(val)
635+ vals = all_gather_str (val )
636+
637+ tc = unittest .TestCase ()
638+ tc .assertEqual (vals [0 ], "foo" )
639+ tc .assertEqual (vals [1 ], "barzoo" )
Original file line number Diff line number Diff line change @@ -772,6 +772,46 @@ def broadcast_str(
772772 return string
773773
774774
def all_gather_str(
    val: str, process_group: Optional[dist.ProcessGroup] = None
) -> List[str]:
    """
    Optimized all-gather of a string without invoking ``all_gather_object``,
    which is subject to hang issues on nccl.

    Args:
        val: string this rank contributes to the all_gather
        process_group: the process group to gather in (defaults to the
            global process group)

    Returns:
        List of the gathered strings, indexed by rank. If the distributed
        package is unavailable or uninitialized, returns ``[val]``.

    TODO: support fixed_buffer_size
    TODO: route through a temporary gloo process group (see
        ``get_or_create_gloo_pg``) to avoid device-to-host transfers
    """

    # Single-process fallback: nothing to gather.
    if not dist.is_available() or not dist.is_initialized():
        return [val]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoded = val.encode("utf-8")
    if encoded:
        # Copy into a bytearray: torch.frombuffer warns on read-only buffers
        # such as bytes objects.
        local_buffer = torch.frombuffer(bytearray(encoded), dtype=torch.uint8).to(
            device
        )
    else:
        # torch.frombuffer rejects zero-length buffers, so handle "" explicitly.
        local_buffer = torch.empty(0, dtype=torch.uint8, device=device)

    # `all_gather_tensors` handles gathering tensors of the same dtype but
    # different lengths (strings may differ in length across ranks).
    gathered = all_gather_tensors(local_buffer, group=process_group)

    return [bytes(t.tolist()).decode("utf-8") for t in gathered]
813+
814+
775815@contextmanager
776816def get_or_create_gloo_pg (
777817 candidate_pg : Optional [dist .ProcessGroup ] = None ,
You can’t perform that action at this time.
0 commit comments