@@ -107,6 +107,54 @@ def broadcast_scalar(
     return tensor.item()


+def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None, src: int = 0) -> typing.Any:
+    """
+    Broadcasts a Python object from src rank to all other ranks in the ProcessGroup.
+    Returns the object on all ranks.
+    """
+    assert group is not None
+
+    if group.rank() == src:
+        tensor = _object_to_tensor(input_object)
+        size = tensor.numel()
+        broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
+        broadcast_tensor.copy_(tensor)
+        broadcast_scalar(size, torch.int64, group, src)
+        broadcast(broadcast_tensor, src, group)
+        return input_object
+    else:
+        size = int(broadcast_scalar(None, torch.int64, group, src))
+        output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
+        broadcast(output_tensor, src, group)
+        return _tensor_to_object(output_tensor)
+
+
+def broadcast_optional(tensor: torch.Tensor | None, group: ProcessGroup | None = None, src: int = 0) -> torch.Tensor | None:
+    """
+    Broadcasts an optional tensor whose presence, shape, and dtype are not known in advance on the receiving ranks.
+    Returns the tensor on all ranks, or None if no tensor was sent.
+    """
+    assert group is not None
+
+    if group.rank() == src:
+        has_tensor = tensor is not None
+        if has_tensor:
+            meta = (has_tensor, tensor.shape, tensor.dtype)
+        else:
+            meta = (has_tensor, None, None)
+        broadcast_object(meta, group, src)
+        if has_tensor:
+            broadcast(tensor.to(torch.cuda.current_device()), src, group)
+        return tensor
+    else:
+        has_tensor, shape, dtype = broadcast_object(None, group, src)
+        if not has_tensor:
+            return None
+        output_tensor = torch.empty(shape, dtype=dtype, device=torch.cuda.current_device())
+        broadcast(output_tensor, src, group)
+        return output_tensor
+
+
 def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
     assert group is not None
     work = group.send([tensor], dst, tag)
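
Note for reviewers: below is a minimal usage sketch of the two new helpers, not part of this diff. It assumes an NCCL/CUDA setup (the helpers allocate on torch.cuda.current_device()) and that broadcast_object and broadcast_optional are importable from this module; the group construction and payload values are placeholders.

import torch
import torch.distributed as dist

# broadcast_object / broadcast_optional imported from this module (path omitted).

def sync_from_rank_zero() -> None:
    # Launched with torchrun; each process picks its own GPU.
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    group = dist.group.WORLD

    # Any picklable payload; only the source rank needs a real value.
    config = {"seq_len": 4096, "dropout": 0.1} if group.rank() == 0 else None
    config = broadcast_object(config, group, src=0)  # identical dict on every rank

    # A tensor whose existence, shape, and dtype are only known on the source rank.
    mask = torch.ones(8, 8, device="cuda") if group.rank() == 0 else None
    mask = broadcast_optional(mask, group, src=0)  # tensor on every rank, or None everywhere
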
@@ -186,7 +234,11 @@ def scatter(
 def _object_to_tensor(obj: typing.Any) -> torch.Tensor:
     f = io.BytesIO()
     pickle.Pickler(f).dump(obj)
-    return torch.tensor(torch.UntypedStorage.from_buffer(f.getvalue(), dtype=torch.uint8), dtype=torch.uint8)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with `torch.tensor(...)` and an explicit dtype.
+    # Otherwise, it will cause a ~100x slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    return torch.ByteTensor(byte_storage)


 def _tensor_to_object(tensor: torch.Tensor) -> typing.Any:
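
For context, a CPU-only round-trip check of the serialization path changed above (hypothetical snippet, not part of this diff): _object_to_tensor is copied from this hunk for self-containment, and the inverse shown is only what _tensor_to_object is assumed to do, since its body falls outside the hunk.

import io
import pickle
import typing

import torch

def _object_to_tensor(obj: typing.Any) -> torch.Tensor:
    # Copied from the hunk above so the check is self-contained.
    f = io.BytesIO()
    pickle.Pickler(f).dump(obj)
    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
    return torch.ByteTensor(byte_storage)

obj = {"step": 3, "lr": 1e-4}
t = _object_to_tensor(obj)
assert t.dtype == torch.uint8 and t.dim() == 1

# Assumed inverse of the pickling above (what _tensor_to_object is expected to do).
restored = pickle.Unpickler(io.BytesIO(t.numpy().tobytes())).load()
assert restored == obj
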