
Commit cca5381: add ulysses backward
1 parent 27e1d27

File tree

1 file changed: +27 -1 lines

src/diffusers/models/attention_dispatch.py

Lines changed: 27 additions & 1 deletion
@@ -1040,7 +1040,33 @@ def backward(
         grad_out: torch.Tensor,
         *args,
     ):
-        raise NotImplementedError("Backward pass is not implemented for TemplatedUlyssesAttention.")
+        parallel_config = _AttentionBackendRegistry._parallel_config
+        ulysses_mesh = parallel_config._ulysses_mesh
+        world_size = parallel_config.ulysses_degree
+        group = ulysses_mesh.get_group()
+
+        B, S_LOCAL, H, D = grad_out.shape
+        H_LOCAL = H // world_size
+
+        grad_out = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
+        grad_out = _wait_tensor(funcol.all_to_all_single(grad_out, None, None, group=group))
+        grad_out = grad_out.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
+
+        grad_query_op, grad_key_op, grad_value_op, *_ = ctx.backward_op(ctx.op_ctx, grad_out)
+
+        grad_query, grad_key, grad_value = (
+            x.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
+            for x in (grad_query_op, grad_key_op, grad_value_op)
+        )
+        grad_query, grad_key, grad_value = (
+            _wait_tensor(funcol.all_to_all_single(x, None, None, group=group))
+            for x in (grad_query, grad_key, grad_value)
+        )
+        grad_query, grad_key, grad_value = (
+            x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
+        )
+
+        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None


 def _templated_context_parallel_attention(
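
The new backward mirrors the forward's Ulysses layout changes: the incoming grad_out arrives sequence-sharded with all heads, is resharded to full sequence with local heads via an all-to-all before calling the saved local attention backward op, and the resulting query/key/value gradients are resharded back the other way. Below is a minimal, single-process sketch of that shape bookkeeping; the world size, tensor sizes, and the fake_all_to_all helper are illustrative assumptions, and the identity stand-in is only valid for layout reasoning because funcol.all_to_all_single with None split sizes keeps the tensor shape while exchanging dim-0 chunks across ranks.

import torch

# Illustrative assumptions: ulysses_degree and the local gradient shape.
world_size = 4
B, S_LOCAL, H, D = 2, 8, 16, 64          # grad_out is sequence-sharded, all heads
H_LOCAL = H // world_size

def fake_all_to_all(x):
    # Stand-in for _wait_tensor(funcol.all_to_all_single(x, None, None, group=group)).
    # The real collective swaps dim-0 chunks between ranks but preserves the shape,
    # so an identity is enough to check the layout arithmetic below.
    return x

# (1) grad_out: (B, S_LOCAL, H, D) -> (B, S, H_LOCAL, D), i.e. full sequence, local heads,
#     matching the layout the attention op used in the forward pass.
grad_out = torch.randn(B, S_LOCAL, H, D)
x = grad_out.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
x = fake_all_to_all(x)
x = x.flatten(0, 1).permute(1, 0, 2, 3).contiguous()
assert x.shape == (B, S_LOCAL * world_size, H_LOCAL, D)

# (2) the q/k/v grads produced by ctx.backward_op come back in that same layout and are
#     resharded the other way: (B, S, H_LOCAL, D) -> (B, S_LOCAL, H, D).
grad_q = torch.randn(B, S_LOCAL * world_size, H_LOCAL, D)
y = grad_q.reshape(B, world_size, S_LOCAL, H_LOCAL, D).permute(1, 3, 0, 2, 4).contiguous()
y = fake_all_to_all(y)
y = y.flatten(0, 1).permute(1, 2, 0, 3).contiguous()
assert y.shape == (B, S_LOCAL, H, D)

The trailing None returns in the diff follow the usual torch.autograd.Function convention: one entry per remaining forward input that does not need a gradient.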
