Commit f859fdf

refactor; support flash attention 2 with context parallelism (cp)
1 parent 7973626 commit f859fdf


src/diffusers/models/attention_dispatch.py

Lines changed: 33 additions & 20 deletions
@@ -571,8 +571,8 @@ def forward(
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
-        scale: Optional[float] = None,
         is_causal: bool = False,
+        scale: Optional[float] = None,
         enable_gqa: bool = False,
         return_lse: bool = False,
     ):
@@ -653,8 +653,8 @@ def forward(
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
-        scale: Optional[float] = None,
         is_causal: bool = False,
+        scale: Optional[float] = None,
         enable_gqa: bool = False,
         return_lse: bool = False,
     ):
@@ -753,8 +753,8 @@ def forward(
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
         dropout_p: float,
-        scale: Optional[float],
         is_causal: bool,
+        scale: Optional[float],
         enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
@@ -778,7 +778,7 @@ def forward(
             value = kv[key.numel() :].reshape_as(value)
             next_rank = (next_rank + 1) % world_size
 
-            out, lse = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, True)
+            out, lse = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True)
 
             if parallel_config.convert_to_fp32:
                 out = out.to(torch.float32)
@@ -813,8 +813,8 @@ def forward(
         value: torch.Tensor,
         attn_mask: Optional[torch.Tensor],
         dropout_p: float,
-        scale: Optional[float],
         is_causal: bool,
+        scale: Optional[float],
         enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
@@ -833,7 +833,7 @@ def forward(
         query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))
 
-        out = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
+        out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
         if return_lse:
             out, lse, *_ = out
 
@@ -883,14 +883,14 @@ def _templated_context_parallel_attention(
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
     if parallel_config.ring_degree > 1:
         return TemplatedRingAttention.apply(
-            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
         )
     elif parallel_config.ulysses_degree > 1:
         return TemplatedUlyssesAttention.apply(
-            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+            query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, op
         )
     else:
-        return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
+        return op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
 
 
 # ===== Attention backends =====
@@ -905,20 +905,33 @@ def _flash_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     dropout_p: float = 0.0,
-    scale: Optional[float] = None,
     is_causal: bool = False,
+    scale: Optional[float] = None,
     return_lse: bool = False,
 ) -> torch.Tensor:
-    out = flash_attn_func(
-        q=query,
-        k=key,
-        v=value,
-        dropout_p=dropout_p,
-        softmax_scale=scale,
-        causal=is_causal,
-        return_attn_probs=return_lse,
-    )
-    return out
+    parallel_config = _AttentionBackendRegistry._parallel_config
+
+    lse = None
+    if parallel_config is None:
+        out = flash_attn_func(
+            q=query,
+            k=key,
+            v=value,
+            dropout_p=dropout_p,
+            softmax_scale=scale,
+            causal=is_causal,
+            return_attn_probs=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        out = _templated_context_parallel_attention(
+            query, key, value, None, dropout_p, is_causal, scale, False, return_lse, op=_flash_attention_2
+        )
+        if return_lse:
+            out, lse = out
+
+    return (out, lse) if return_lse else out
 
 
 @_AttentionBackendRegistry.register(
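
The contract this refactor relies on is a single positional argument order, (query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse), shared by every backend op that the templated ring/Ulysses wrappers invoke via op.apply(...); moving scale after is_causal and routing _flash_attention through _templated_context_parallel_attention when a parallel config is registered keeps the flash attention 2 backend on that contract. The sketch below is a minimal, standalone illustration of the dispatch pattern, not the diffusers implementation: ToyAttentionOp, toy_context_parallel_attention, the module-level _parallel_config flag, and the use of torch.nn.functional.scaled_dot_product_attention as a stand-in for the FlashAttention-2 kernel are all assumptions made for the example.

import torch
import torch.nn.functional as F


class ToyAttentionOp(torch.autograd.Function):
    # Backend ops are called positionally by the templated wrappers, so they must all
    # accept: (query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse).
    @staticmethod
    def forward(ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse):
        # enable_gqa and return_lse are accepted for signature compatibility; this toy op ignores them.
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal, scale=scale
        )


def toy_context_parallel_attention(query, key, value, attn_mask, dropout_p, is_causal, scale,
                                   enable_gqa, return_lse, op):
    # Placeholder for the templated ring/Ulysses path: a real implementation would shard the
    # sequence across ranks and combine the partial results, but it would still call
    # op.apply(...) with exactly this positional order.
    return op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)


_parallel_config = None  # stands in for _AttentionBackendRegistry._parallel_config


def toy_flash_attention(query, key, value, dropout_p=0.0, is_causal=False, scale=None):
    if _parallel_config is None:
        # Single-device path: call the attention kernel directly.
        return ToyAttentionOp.apply(query, key, value, None, dropout_p, is_causal, scale, False, False)
    # Context-parallel path: hand the op to the templated wrapper instead.
    return toy_context_parallel_attention(
        query, key, value, None, dropout_p, is_causal, scale, False, False, op=ToyAttentionOp
    )


if __name__ == "__main__":
    q = k = v = torch.randn(2, 8, 128, 64)  # (batch, heads, seq_len, head_dim)
    print(toy_flash_attention(q, k, v, is_causal=True).shape)  # torch.Size([2, 8, 128, 64])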
