Skip to content

Commit 5643bc9

Browse files
Apply style fixes
1 parent 68f4fd5 commit 5643bc9

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1187,10 +1187,14 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
     num_heads = num_heads_local * group_world_size
     seq_len_local = seq_len // group_world_size

-    #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
-    x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
+    # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+    x_temp = (
+        x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
+        .permute(1, 3, 2, 0, 4)
+        .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
+    )

-    if group_world_size >1:
+    if group_world_size > 1:
         output = _all_to_all_single(x_temp, group)
     else:
         output = x_temp

0 commit comments

Comments (0)