We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b14cad4 commit c5b0069Copy full SHA for c5b0069
ssd/utils/distributed_util.py
@@ -71,8 +71,8 @@ def all_gather(data):
71
tensor = torch.ByteTensor(storage).to("cuda")
72
73
# obtain Tensor size of each rank
74
- local_size = torch.IntTensor([tensor.numel()]).to("cuda")
75
- size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
+ local_size = torch.LongTensor([tensor.numel()]).to("cuda")
+ size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)]
76
dist.all_gather(size_list, local_size)
77
size_list = [int(size.item()) for size in size_list]
78
max_size = max(size_list)
0 commit comments