|
20 | 20 | from torch import nn |
21 | 21 |
|
22 | 22 | from ..image_processor import IPAdapterMaskProcessor |
23 | | -from ..utils import deprecate, logging |
| 23 | +from ..utils import deprecate, logging, is_torch_xla_available |
24 | 24 | from ..utils.import_utils import is_torch_npu_available, is_xformers_available |
25 | 25 | from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph |
26 | 26 |
|
| 27 | +if is_torch_xla_available(): |
| 28 | + from torch_xla.experimental.custom_kernel import flash_attention |
| 29 | + XLA_AVAILABLE = True |
| 30 | +else: |
| 31 | + XLA_AVAILABLE = False |
27 | 32 |
|
28 | 33 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
29 | 34 |
|
@@ -1762,7 +1767,12 @@ def __call__( |
1762 | 1767 | query = apply_rotary_emb(query, image_rotary_emb) |
1763 | 1768 | key = apply_rotary_emb(key, image_rotary_emb) |
1764 | 1769 |
|
1765 | | - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) |
| 1770 | + if XLA_AVAILABLE: |
| 1771 | + query /= math.sqrt(head_dim) |
| 1772 | + hidden_states = flash_attention(query, key, value, causal=False) |
| 1773 | + else: |
| 1774 | + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) |
| 1775 | + |
1766 | 1776 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
1767 | 1777 | hidden_states = hidden_states.to(query.dtype) |
1768 | 1778 |
|
@@ -1856,7 +1866,11 @@ def __call__( |
1856 | 1866 | query = apply_rotary_emb(query, image_rotary_emb) |
1857 | 1867 | key = apply_rotary_emb(key, image_rotary_emb) |
1858 | 1868 |
|
1859 | | - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) |
| 1869 | + if XLA_AVAILABLE: |
| 1870 | + query /= math.sqrt(head_dim) |
| 1871 | + hidden_states = flash_attention(query, key, value) |
| 1872 | + else: |
| 1873 | + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) |
1860 | 1874 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) |
1861 | 1875 | hidden_states = hidden_states.to(query.dtype) |
1862 | 1876 |
|
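
The pattern introduced by this diff can be summarized outside the processor classes: guard the torch_xla import at module load, and at attention time dispatch to the Pallas `flash_attention` kernel on XLA devices, falling back to `F.scaled_dot_product_attention` elsewhere. A minimal sketch is below; it assumes tensors shaped `[batch, heads, seq_len, head_dim]`, and the explicit `query / sqrt(head_dim)` mirrors the diff, since the XLA kernel is not given a softmax scale here the way SDPA applies one internally. The helper name `dispatch_attention` is illustrative, not part of the library.

```python
# Sketch of the XLA flash-attention dispatch used in the diff above.
import math

import torch
import torch.nn.functional as F

try:
    # Pallas flash-attention kernel shipped with torch_xla (TPU).
    from torch_xla.experimental.custom_kernel import flash_attention
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def dispatch_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    """Non-causal attention over [batch, heads, seq, head_dim] tensors."""
    if XLA_AVAILABLE:
        # Scale the query explicitly, as in the diff, before calling the kernel.
        query = query / math.sqrt(query.shape[-1])
        return flash_attention(query, key, value, causal=False)
    # Fallback path: PyTorch's fused SDPA, which applies 1/sqrt(head_dim) itself.
    return F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
```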
|