Skip to content

Commit 773a5c7

Browse files
committed
Remove unused gather function in convtranspose test
1 parent 9569098 commit 773a5c7

File tree

1 file changed

+0
-19
lines changed

1 file changed

+0
-19
lines changed

tests/test_convtranspose.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,6 @@
77
from utils import cleanup_parallel_strategy, fp32_allclose
88

99

10-
def all_gather_vlen(tensor: torch.Tensor, group=None, dim=0) -> list[torch.Tensor]:
11-
"""Gather tensors with the same number of dimensions but different lengths.
12-
13-
Credit: https://stackoverflow.com/a/78934638
14-
"""
15-
world_size = dist.get_world_size(group=group)
16-
# Gather lengths first
17-
shape = torch.as_tensor(tensor.shape, device=tensor.device)
18-
shapes = [torch.empty_like(shape) for _ in range(world_size)]
19-
dist.all_gather(shapes, shape, group=group)
20-
# Gather data
21-
inputs = [tensor] * world_size
22-
outputs = [
23-
torch.empty(*_shape, dtype=tensor.dtype, device=tensor.device)
24-
for _shape in shapes
25-
]
26-
dist.all_to_all(outputs, inputs, group=group)
27-
return torch.cat(outputs, dim=dim)
28-
2910

3011
@pytest.fixture(scope="module")
3112
def parallel_strategy(device: torch.device):

0 commit comments

Comments
 (0)