Skip to content

Commit 40b24c6

Browse files
committed
Update
1 parent 9264655 commit 40b24c6

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,9 @@ def _flash_attention(
576576
query: torch.Tensor,
577577
key: torch.Tensor,
578578
value: torch.Tensor,
579+
dropout_p: float = 0.0,
580+
scale: Optional[float] = None,
581+
is_causal: bool = False,
579582
window_size: Tuple[int, int] = (-1, -1),
580583
softcap: float = 0.0,
581584
alibi_slopes: Optional[torch.Tensor] = None,
@@ -586,6 +589,9 @@ def _flash_attention(
586589
q=query,
587590
k=key,
588591
v=value,
592+
dropout_p=dropout_p,
593+
softmax_scale=scale,
594+
causal=is_causal,
589595
window_size=window_size,
590596
softcap=softcap,
591597
alibi_slopes=alibi_slopes,
@@ -748,7 +754,6 @@ def _flash_attention_hub(
748754
query: torch.Tensor,
749755
key: torch.Tensor,
750756
value: torch.Tensor,
751-
dropout_p: float = 0.0,
752757
scale: Optional[float] = None,
753758
is_causal: bool = False,
754759
window_size: Tuple[int, int] = (-1, -1),
@@ -761,7 +766,6 @@ def _flash_attention_hub(
761766
q=query,
762767
k=key,
763768
v=value,
764-
dropout_p=dropout_p,
765769
softmax_scale=scale,
766770
causal=is_causal,
767771
window_size=window_size,

0 commit comments

Comments
 (0)