Commit 1fbbf6c

initial all_to_all_double
1 parent aec0804 commit 1fbbf6c

File tree

1 file changed: +45 -2 lines changed


src/diffusers/models/attention_dispatch.py

Lines changed: 45 additions & 2 deletions
@@ -1022,11 +1022,54 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
     return x
 
 
 def _all_to_all_double(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
-    pass
+    # Number of ranks taking part in the exchange (assumes `torch.distributed` is imported as `dist`).
+    group_world_size = dist.get_world_size(group)
+
+    if scatter_idx == 2 and gather_idx == 1:
+        # Scatter heads, gather sequence: (B, S_LOCAL, H, D) -> (B, S, H_LOCAL, D).
+        B, S_LOCAL, H, D = x.shape
+        S = S_LOCAL * group_world_size
+        H_LOCAL = H // group_world_size
+
+        # Split the head dim into (world_size, H_LOCAL) and move the world_size dim to the
+        # front, since all_to_all_single scatters and gathers along dim 0.
+        x_temp = (
+            x.reshape(B, S_LOCAL, group_world_size, H_LOCAL, D)
+            .permute(2, 1, 0, 3, 4)
+            .contiguous()
+        )
+
+        out = torch.empty_like(x_temp)
+        if group_world_size > 1:
+            dist.all_to_all_single(out, x_temp, group=group)
+        else:
+            out = x_temp
+        # (world_size, S_LOCAL, B, H_LOCAL, D) -> (S, B, H_LOCAL, D) -> (B, S, H_LOCAL, D)
+        out = out.reshape(S, B, H_LOCAL, D).permute(1, 0, 2, 3).contiguous()
+        return out
+    elif scatter_idx == 1 and gather_idx == 2:
+        # Scatter sequence, gather heads: (B, S, H_LOCAL, D) -> (B, S_LOCAL, H, D).
+        B, S, H_LOCAL, D = x.shape
+        H = H_LOCAL * group_world_size
+        S_LOCAL = S // group_world_size
+
+        # Split the sequence dim into (world_size, S_LOCAL) and move the world_size dim to
+        # the front, since all_to_all_single scatters and gathers along dim 0.
+        x_temp = (
+            x.reshape(B, group_world_size, S_LOCAL, H_LOCAL, D)
+            .permute(1, 3, 2, 0, 4)
+            .contiguous()
+        )
+
+        output = torch.empty_like(x_temp)
+        if group_world_size > 1:
+            dist.all_to_all_single(output, x_temp, group=group)
+        else:
+            output = x_temp
+        # (world_size, H_LOCAL, S_LOCAL, B, D) -> (H, S_LOCAL, B, D) -> (B, S_LOCAL, H, D)
+        output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
+        return output
+    else:
+        raise ValueError("Invalid scatter/gather indices for all_to_all_double.")
 
 
 class SeqAllToAllDouble(torch.autograd.Function):
-    pass
+    @staticmethod
+    def forward(ctx, *args):
+        # Not yet implemented in this commit.
+        pass
+
+    @staticmethod
+    def backward(ctx, *grad_outputs):
+        # Not yet implemented in this commit.
+        pass
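Taken together, the two branches are the Ulysses-style pair of exchanges for sequence-parallel attention: scatter heads and gather the sequence before computing attention, then undo it afterwards. Below is a rough usage sketch; the torchrun launch, NCCL backend, tensor sizes, and the import of the private helper are assumptions for illustration, not part of the commit.

# Sketch only: run with e.g. `torchrun --nproc_per_node=2 demo.py` on 2 GPUs.
import os

import torch
import torch.distributed as dist

from diffusers.models.attention_dispatch import _all_to_all_double  # assumed import path

# all_to_all_single needs a backend that supports it; NCCL does, so this assumes GPUs.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world_size = dist.get_world_size()

B, S_LOCAL, H, D = 2, 8, 4, 16  # H must be divisible by world_size
x = torch.randn(B, S_LOCAL, H, D, device="cuda")

# Before attention: trade heads for sequence, (B, S_LOCAL, H, D) -> (B, S, H_LOCAL, D).
q = _all_to_all_double(x, scatter_idx=2, gather_idx=1)  # with 2 ranks: (2, 16, 2, 16)

# After attention: trade sequence for heads, back to (B, S_LOCAL, H, D).
y = _all_to_all_double(q, scatter_idx=1, gather_idx=2)
assert y.shape == x.shape

dist.destroy_process_group()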
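The stubbed SeqAllToAllDouble presumably becomes the autograd wrapper around _all_to_all_double, following the usual pattern where backward is the same exchange with the scatter and gather indices swapped (the transpose of an all-to-all is the reverse all-to-all). A minimal sketch under that assumption; the argument order and saved state are guesses, not the commit's final API.

import torch

from diffusers.models.attention_dispatch import _all_to_all_double  # assumed import path


class SeqAllToAllDouble(torch.autograd.Function):
    """Sketch of an autograd wrapper around _all_to_all_double; not the commit's implementation."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, scatter_idx: int, gather_idx: int, group=None) -> torch.Tensor:
        ctx.scatter_idx = scatter_idx
        ctx.gather_idx = gather_idx
        ctx.group = group
        return _all_to_all_double(x, scatter_idx, gather_idx, group)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Gradients flow back through the inverse exchange: swap scatter and gather.
        grad_input = _all_to_all_double(grad_output.contiguous(), ctx.gather_idx, ctx.scatter_idx, ctx.group)
        return grad_input, None, None, None


# Usage: out = SeqAllToAllDouble.apply(x, 2, 1, group)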