Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchmetrics/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
torch.distributed.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
gathered_result[idx] = gathered_result[idx][tuple(slice_param)]
# to propagate autograd graph from local rank
gathered_result[torch.distributed.get_rank(group)] = result
return gathered_result
Loading