
Commit da78c5d

workaround compilation problems with triton when doing all-to-all
1 parent 2065acc
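
Editor's note (not part of the commit): the generated line quoted in the new
helper's comment shows inductor lowering funcol.all_to_all_single to
torch.ops._c10d_functional.all_to_all_single.default, which rejects the
multi-dimensional input with "ValueError: Tensors must be contiguous" even
though the call sites make their tensors contiguous first. The workaround is
to route every all-to-all through a helper that exchanges a flat 1-D buffer.
A minimal, hypothetical sketch of the kind of compiled call involved (names,
shapes, and the nccl setup are illustrative, not taken from the commit):

    # sketch.py -- run with e.g. torchrun --nproc-per-node 2 sketch.py
    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as funcol

    def head_scatter(x: torch.Tensor) -> torch.Tensor:
        # (B, S_LOCAL, world_size, H_LOCAL, D): move the exchange dim first,
        # then all-to-all across ranks, as the forward pass in this file does.
        x = x.permute(2, 1, 0, 3, 4).contiguous()
        return funcol.all_to_all_single(x, None, None, dist.group.WORLD)

    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank())
    x = torch.randn(1, 8, dist.get_world_size(), 4, 16, device="cuda")
    # Compiling a graph that contains the multi-dimensional collective is
    # where the "Tensors must be contiguous" error reportedly surfaced.
    out = torch.compile(head_scatter)(x)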

1 file changed

src/diffusers/models/attention_dispatch.py

Lines changed: 25 additions & 17 deletions
@@ -856,6 +856,19 @@ def _wait_tensor(tensor):
     return tensor
 
 
+def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
+    shape = x.shape
+    # HACK: We need to flatten because despite making tensors contiguous, torch single-file-ization
+    # to benchmark triton codegen fails somewhere:
+    # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
+    # ValueError: Tensors must be contiguous
+    x = x.flatten()
+    x = funcol.all_to_all_single(x, None, None, group)
+    x = x.reshape(shape)
+    x = _wait_tensor(x)
+    return x
+
+
 class TemplatedRingAttention(torch.autograd.Function):
     @staticmethod
     def forward(
@@ -1003,28 +1016,26 @@ def forward(
         ctx.backward_op = backward_op
         ctx.op_ctx = torch.autograd.function.FunctionCtx()
 
-        B, S_LOCAL, H, D = query.shape
+        B, S_Q_LOCAL, H, D = query.shape
+        _, S_KV_LOCAL, _, _ = key.shape
         H_LOCAL = H // world_size
-        query, key, value = (
-            x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
-            for x in (query, key, value)
-        )
-        query, key, value = (
-            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group)) for x in (query, key, value)
-        )
+        query = query.reshape(B, S_Q_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        key = key.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        value = value.reshape(B, S_KV_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        query, key, value = (_all_to_all_single(x, group) for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
 
         out = forward_op(ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
         if return_lse:
             out, lse, *_ = out
 
-        out = out.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
-        out = _wait_tensor(funcol.all_to_all_single(out, None, None, group=group))
+        out = out.reshape(B, world_size, S_Q_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+        out = _all_to_all_single(out, group)
         out = out.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
 
         if return_lse:
-            lse = lse.reshape(B, world_size, S_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
-            lse = _wait_tensor(funcol.all_to_all_single(lse, None, None, group=group))
+            lse = lse.reshape(B, world_size, S_Q_LOCAL, H_LOCAL).permute(1, 3, 0, 2).contiguous()
+            lse = _all_to_all_single(lse, group)
             lse = lse.flatten(0, 1).permute(1, 2, 0).contiguous()
         else:
             lse = None
@@ -1046,7 +1057,7 @@ def backward(
         H_LOCAL = H // world_size
 
         grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
-        grad_out = _wait_tensor(funcol.all_to_all_single(grad_out, None, None, group=group))
+        grad_out = _all_to_all_single(grad_out, group)
        grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
 
         grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
@@ -1055,10 +1066,7 @@ def backward(
             x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
             for x in (grad_query_op, grad_key_op, grad_value_op)
         )
-        grad_query, grad_key, grad_value = (
-            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group))
-            for x in (grad_query, grad_key, grad_value)
-        )
+        grad_query, grad_key, grad_value = (_all_to_all_single(x, group) for x in (grad_query, grad_key, grad_value))
         grad_query, grad_key, grad_value = (
             x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
         )
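
Editor's note: for reference outside the diff, the new helper reads as below.
Only _all_to_all_single is taken from the commit; the body of _wait_tensor is
an assumption reconstructed from its trailing `return tensor` shown at the top
of the diff, and the comments are the editor's.

    import torch
    import torch.distributed._functional_collectives as funcol

    def _wait_tensor(tensor):
        # Assumed body: block until the async functional collective that
        # produced `tensor` has completed.
        if isinstance(tensor, funcol.AsyncCollectiveTensor):
            tensor = tensor.wait()
        return tensor

    def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
        # Remember the shape, exchange a flat 1-D buffer, restore the shape.
        # A flat buffer is trivially contiguous, which sidesteps the
        # "Tensors must be contiguous" failure in triton codegen.
        shape = x.shape
        x = x.flatten()
        x = funcol.all_to_all_single(x, None, None, group)
        x = x.reshape(shape)
        x = _wait_tensor(x)
        return x

The flattening preserves the collective's semantics because every call site
first makes x contiguous with the scatter dimension (of size world_size)
leading: splitting the flat buffer into world_size equal chunks selects
exactly the original dim-0 slices, so each rank sends and receives the same
elements as before.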
