Skip to content

Commit 854b5ea

Browse files
authored
Updated typehint for idist.all_gather method (#3089)
Description: - This reflects the fact that idist can use all_gather_object internally
1 parent f26078f commit 854b5ea

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

ignite/distributed/comp_models/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def all_reduce(
214214

215215
def all_gather(
216216
self, tensor: Union[torch.Tensor, float, str, Any], group: Optional[Any] = None
217-
) -> Union[torch.Tensor, float, List[float], List[str]]:
217+
) -> Union[torch.Tensor, float, List[float], List[str], List[Any]]:
218218
if not isinstance(tensor, (torch.Tensor, Number, str)):
219219
return self._do_all_gather_object(tensor, group=group)
220220

@@ -355,11 +355,11 @@ def all_reduce(
355355
return tensor
356356

357357
def all_gather(
358-
self, tensor: Union[torch.Tensor, float, str], group: Optional[Any] = None
359-
) -> Union[torch.Tensor, float, List[float], List[str]]:
358+
self, tensor: Union[torch.Tensor, float, str, Any], group: Optional[Any] = None
359+
) -> Union[torch.Tensor, float, List[float], List[str], List[Any]]:
360360
if isinstance(tensor, torch.Tensor):
361361
return tensor
362-
return cast(Union[List[float], List[str]], [tensor])
362+
return cast(Union[List[float], List[str], List[Any]], [tensor])
363363

364364
def broadcast(
365365
self, tensor: Union[torch.Tensor, float, str, None], src: int = 0, safe_mode: bool = False

ignite/distributed/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,8 +351,8 @@ def all_reduce(
351351

352352

353353
def all_gather(
354-
tensor: Union[torch.Tensor, float, str], group: Optional[Union[Any, List[int]]] = None
355-
) -> Union[torch.Tensor, float, List[float], List[str]]:
354+
tensor: Union[torch.Tensor, float, str, Any], group: Optional[Union[Any, List[int]]] = None
355+
) -> Union[torch.Tensor, float, List[float], List[str], List[Any]]:
356356
"""Helper method to perform all gather operation.
357357
358358
Args:

0 commit comments

Comments
 (0)