@@ -214,7 +214,7 @@ def all_reduce(
214214
215215 def all_gather (
216216 self , tensor : Union [torch .Tensor , float , str , Any ], group : Optional [Any ] = None
217- ) -> Union [torch .Tensor , float , List [float ], List [str ]]:
217+ ) -> Union [torch .Tensor , float , List [float ], List [str ], List [ Any ] ]:
218218 if not isinstance (tensor , (torch .Tensor , Number , str )):
219219 return self ._do_all_gather_object (tensor , group = group )
220220
@@ -355,11 +355,11 @@ def all_reduce(
355355 return tensor
356356
357357 def all_gather (
358- self , tensor : Union [torch .Tensor , float , str ], group : Optional [Any ] = None
359- ) -> Union [torch .Tensor , float , List [float ], List [str ]]:
358+ self , tensor : Union [torch .Tensor , float , str , Any ], group : Optional [Any ] = None
359+ ) -> Union [torch .Tensor , float , List [float ], List [str ], List [ Any ] ]:
360360 if isinstance (tensor , torch .Tensor ):
361361 return tensor
362- return cast (Union [List [float ], List [str ]], [tensor ])
362+ return cast (Union [List [float ], List [str ], List [ Any ] ], [tensor ])
363363
364364 def broadcast (
365365 self , tensor : Union [torch .Tensor , float , str , None ], src : int = 0 , safe_mode : bool = False
0 commit comments