
Commit de5faf8

add flash attention.
1 parent c5b4a3c commit de5faf8

File tree

1 file changed: +17 -3 lines changed

src/diffusers/models/attention_processor.py

Lines changed: 17 additions & 3 deletions
@@ -20,10 +20,15 @@
 from torch import nn
 
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import deprecate, logging
+from ..utils import deprecate, logging, is_torch_xla_available
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph
 
+if is_torch_xla_available():
+    from torch_xla.experimental.custom_kernel import flash_attention
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -1762,7 +1767,12 @@ def __call__(
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        if XLA_AVAILABLE:
+            query /= math.sqrt(head_dim)
+            hidden_states = flash_attention(query, key, value, causal=False)
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
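The manual pre-scaling of query above is what keeps the two branches numerically consistent: F.scaled_dot_product_attention divides the attention scores by sqrt(head_dim) internally, while flash_attention is called here without any scale argument, so the commit divides the query by sqrt(head_dim) beforehand. A small CPU-only sanity check (not part of the commit, tensors assumed to be laid out as (batch, heads, seq_len, head_dim)) illustrates the equivalence:

import math
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Reference: SDPA applies the 1/sqrt(head_dim) scaling internally.
ref = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# Un-scaled attention on a pre-scaled query, mimicking the XLA branch.
q_scaled = q / math.sqrt(q.shape[-1])
attn = torch.softmax(q_scaled @ k.transpose(-2, -1), dim=-1)
out = attn @ v

assert torch.allclose(ref, out, atol=1e-5)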

@@ -1856,7 +1866,11 @@ def __call__(
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        if XLA_AVAILABLE:
+            query /= math.sqrt(head_dim)
+            hidden_states = flash_attention(query, key, value)
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
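Taken together, both processors get the same dispatch: use the torch_xla flash_attention kernel when torch_xla is importable, otherwise fall back to PyTorch SDPA. A self-contained sketch of that pattern (the helper name attention_forward is illustrative and not from the commit; tensor layout assumed to be (batch, heads, seq_len, head_dim)):

import math

import torch
import torch.nn.functional as F

from diffusers.utils import is_torch_xla_available

if is_torch_xla_available():
    from torch_xla.experimental.custom_kernel import flash_attention
    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False


def attention_forward(query, key, value):
    """Illustrative helper mirroring the branch added in this commit."""
    if XLA_AVAILABLE:
        # Pre-scale the query because the kernel is invoked without a
        # softmax scale, matching what the commit does in both processors.
        query = query / math.sqrt(query.shape[-1])
        return flash_attention(query, key, value, causal=False)
    return F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)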
