Commit 3743558

bug fixes, added cmnts
1 parent 1fbbf6c commit 3743558

File tree

1 file changed, +11 -9 lines changed

1 file changed

+11
-9
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 11 additions & 9 deletions
@@ -1021,22 +1021,23 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     x = _wait_tensor(x)
     return x
 
-def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
+def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
     group_world_size = funcol.get_world_size(group)
-    #dist.get_world_size(group)
 
     if scatter_idx == 2 and gather_idx == 1:
         B, S_LOCAL, H, D = x.shape
         S = S_LOCAL * group_world_size
         H_LOCAL = H // group_world_size
 
+        # B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
         x_temp = (x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
-                  .permute(0, 2, 1, 3, 4).contiguous()
+                  .transpose(0, 2).contiguous()
         )
 
-        out = torch.empty_like(x_temp)
         if group_world_size > 1:
-            funcol.all_to_all_single(out, x_temp, None, None, group)
+            # maybe the _all_to_all_single helper should be used here to avoid contiguity issues
+            out = funcol.all_to_all_single(x_temp, None, None, group=group)
+            out = _wait_tensor(out)
         else:
             out = x_temp
         out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
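
For context on the change in this hunk: the functional collective torch.distributed._functional_collectives.all_to_all_single returns a new (possibly asynchronous) tensor rather than writing into a preallocated output the way torch.distributed.all_to_all_single does, which is why the torch.empty_like buffer is dropped and the result is assigned and then waited on. A minimal sketch of that pattern, assuming an initialized process group; the helper name exchange_dim0 and the even dim-0 split implied by passing None split sizes are assumptions, not part of the commit:

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def exchange_dim0(x_temp: torch.Tensor, group=None) -> torch.Tensor:
    # Hypothetical helper mirroring the call pattern in the hunk above:
    # exchange equal chunks of dim 0 across all ranks in `group`.
    world_size = dist.get_world_size(group)
    if world_size == 1:
        return x_temp
    # Functional collective: returns a tensor (async under the hood) instead of
    # filling an output buffer; None split sizes mean even chunks along dim 0.
    out = funcol.all_to_all_single(x_temp.contiguous(), None, None, group=group)
    # Block until the collective has completed before reshaping downstream
    # (the file's _wait_tensor helper plays this role in the diff).
    return funcol.wait_tensor(out)
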
@@ -1050,19 +1051,20 @@ def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int =
         #
         x_temp = (x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
                   .permute(1, 3, 2, 0, 4).reshape(group_world_size, H_LOCAL, S_LOCAL, B, D))
-        output = torch.empty_like(x_temp)
+
         if group_world_size > 1:
-            funcol.all_to_all_single(output, x_temp, None, None, group)
+            output = funcol.all_to_all_single(x_temp, None, None, group)
+            output = _wait_tensor(output)
         else:
             output = x_temp
         output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
         output = output.reshape(B, S_LOCAL, H, D)
         return output
     else:
-        raise ValueError("Invalid scatter/gather indices for all_to_all_double.")
+        raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
 
 
-class SeqAllToAllDouble(torch.autograd.Function):
+class SeqAllToAllDim(torch.autograd.Function):
     @staticmethod
     def forward():
         pass
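
A shape-level usage sketch of the renamed helper in Ulysses-style sequence parallelism; the concrete sizes, the `group` handle, and the assumption that the second branch corresponds to scatter_idx=1, gather_idx=2 are illustrative, not taken from this commit:

import torch

# Assumes torch.distributed is initialized and `group` is the sequence-parallel group.
B, S_LOCAL, H, D = 2, 128, 16, 64   # per-rank sequence shard
q_shard = torch.randn(B, S_LOCAL, H, D)

# scatter_idx=2, gather_idx=1: trade heads for sequence,
# (B, S_LOCAL, H, D) -> (B, S_LOCAL * world_size, H // world_size, D),
# so each rank attends over the full sequence with its local slice of heads.
q_full_seq = _all_to_all_dim_exchange(q_shard, scatter_idx=2, gather_idx=1, group=group)

# ... attention over the full sequence with H // world_size heads per rank ...

# The reverse exchange (presumably scatter_idx=1, gather_idx=2) restores (B, S_LOCAL, H, D).
out_shard = _all_to_all_dim_exchange(q_full_seq, scatter_idx=1, gather_idx=2, group=group)
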
