
Commit da74834

code formatting fix

1 parent 5579cf1 · commit da74834

File tree: 2 files changed, +5 −7 lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 3 additions & 7 deletions

@@ -1096,14 +1096,10 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:
     num_heads = num_heads_local * group_world_size
     seq_len_local = seq_len // group_world_size
 
-    # B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
-    x_temp = (
-        x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
-        .permute(1, 3, 2, 0, 4)
-        .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
-    )
+    #B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
+    x_temp = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
 
-    if group_world_size > 1:
+    if group_world_size >1:
         output = _all_to_all_single(x_temp, group)
     else:
         output = x_temp
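
For reference, a minimal standalone sketch (not part of this commit) checking that the single-line reshape/permute/reshape added above produces the same (group_world_size, H_LOCAL, S_LOCAL, B, D) layout as the multi-line version it replaces. The tensor sizes below are illustrative assumptions, not values taken from the library.

import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, num_heads_local, head_dim = 2, 8, 4, 16
group_world_size = 2
seq_len_local = seq_len // group_world_size

x = torch.randn(batch_size, seq_len, num_heads_local, head_dim)

# Multi-line form removed by this commit.
a = (
    x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
    .permute(1, 3, 2, 0, 4)
    .reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
)

# Single-line form added by this commit.
b = x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim).permute(1, 3, 2, 0, 4).reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)

# Both yield the layout expected by the all-to-all exchange, so the change is cosmetic.
assert torch.equal(a, b)
print(a.shape)  # torch.Size([2, 4, 4, 2, 16])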

tests/others/test_unified_sp_attention.py

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,8 @@ def run(rank, world_size):
 
     q.requires_grad_(True)
 
+
+
     pc = ParallelConfig(
         context_parallel_config=ContextParallelConfig(
             ring_degree=2,
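
As a side note, the hunk context shows a per-rank entry point run(rank, world_size) that builds a ParallelConfig with ContextParallelConfig(ring_degree=2, ...). Below is a hedged sketch, an assumption rather than code from the test file, of how such an entry point is commonly launched with torch.multiprocessing; the run body here is only a stand-in for the real test.

import torch.multiprocessing as mp

def run(rank, world_size):
    # Stand-in for the test's run(rank, world_size); the real body builds q,
    # ParallelConfig, and ContextParallelConfig(ring_degree=2, ...) as in the diff.
    print(f"rank {rank} of {world_size}")

if __name__ == "__main__":
    world_size = 2  # would line up with ring_degree=2 from the test's config
    # mp.spawn calls run(rank, world_size) once per process, rank in [0, world_size).
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)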

0 commit comments