Commit 3fbd1cf

switched to _all_to_all_single helper in _all_to_all_dim_exchange due to contiguity issues
1 parent 5618a7d commit 3fbd1cf
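
The switch is from calling torch.distributed._functional_collectives.all_to_all_single (plus an explicit _wait_tensor) directly to going through the repo's _all_to_all_single helper. As a rough mental model of why such a helper sidesteps the contiguity issues mentioned above, here is a minimal sketch; the signature and body are assumptions for illustration, not the actual diffusers implementation:

# Hypothetical sketch only: the real _all_to_all_single in diffusers may differ.
import torch
import torch.distributed._functional_collectives as funcol

def _all_to_all_single(x: torch.Tensor, output_split_sizes=None, input_split_sizes=None, group=None):
    # Functional collectives are strict about memory layout, so normalize it here
    # instead of relying on every call site to pass a contiguous tensor.
    x = x.contiguous()
    out = funcol.all_to_all_single(x, output_split_sizes, input_split_sizes, group=group)
    # Resolve the async collective before returning, which would explain why the
    # explicit _wait_tensor(...) calls are commented out at the call sites.
    return funcol.wait_tensor(out)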

File tree

2 files changed (+5, -7 lines)


src/diffusers/models/attention_dispatch.py

Lines changed: 4 additions & 4 deletions
@@ -1036,8 +1036,8 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:

    if group_world_size >1:
        #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
-        out = funcol.all_to_all_single(x_temp, None, None, group=group)
-        out = _wait_tensor(out)
+        out = _all_to_all_single(x_temp, None, None, group=group)
+        #out = _wait_tensor(out)
    else:
        out = x_temp
    # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D

@@ -1055,8 +1055,8 @@ def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx:

    if group_world_size >1:
        #maybe here need to use the _all_to_all_single helper to avoid contiguity issues
-        output = funcol.all_to_all_single(x_temp, None, None, group)
-        output = _wait_tensor(output)
+        output = _all_to_all_single(x_temp, None, None, group)
+        #output = _wait_tensor(output)
    else:
        output = x_temp
    output = output.reshape(H, S_LOCAL, B, D).transpose(0, 2).contiguous()
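
The surrounding reshape/transpose bookkeeping is the usual head-to-sequence all-to-all exchange used for sequence-parallel (Ulysses-style) attention: each rank trades its local sequence shard of all heads for the full sequence of its local head shard, as the group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D comment indicates. A self-contained sketch of that pattern follows; the helper name, shapes, and layout choices are illustrative assumptions rather than the actual _all_to_all_dim_exchange code:

# Illustrative sketch (assumed shapes and helper name), not the diffusers code.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

def gather_seq_scatter_heads(x: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    # (B, S_LOCAL, H, D) -> (B, S, H_LOCAL, D) with a single all-to-all.
    world_size = dist.get_world_size(group)
    B, S_LOCAL, H, D = x.shape
    H_LOCAL = H // world_size
    # Split heads into per-rank blocks and move the block dim to the front so the
    # collective exchanges whole (S_LOCAL, B, H_LOCAL, D) chunks along dim 0.
    x_temp = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
    out = funcol.wait_tensor(funcol.all_to_all_single(x_temp, None, None, group=group))
    # group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S(=world_size * S_LOCAL), H_LOCAL, D
    return out.permute(2, 0, 1, 3, 4).reshape(B, world_size * S_LOCAL, H_LOCAL, D)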

tests/others/test_unified_sp_attention.py

Lines changed: 1 addition & 3 deletions
@@ -102,10 +102,8 @@ def dummy_backward_op(ctx, grad_out, *args, **kwargs):
        grad_v,
    )

-    attn = TemplatedUnifiedAttention()

-    out = attn(
-        None,
+    out = TemplatedUnifiedAttention(
        q, k, v, None,
        dropout_p=0.0,
        is_causal=False,
