38 | 38 | from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS |
39 | 39 |
40 | 40 |
| 41 | +logger = get_logger(__name__) # pylint: disable=invalid-name |
| 42 | + |
41 | 43 | _REQUIRED_FLASH_VERSION = "2.6.3" |
42 | 44 | _REQUIRED_SAGE_VERSION = "2.1.1" |
43 | 45 | _REQUIRED_FLEX_VERSION = "2.5.0" |
52 | 54 | _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) |
53 | 55 | _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) |
54 | 56 |
| 57 | +_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3" |
55 | 58 |
56 | 59 | if _CAN_USE_FLASH_ATTN: |
57 | 60 | from flash_attn import flash_attn_func, flash_attn_varlen_func |
64 | 67 | from flash_attn_interface import flash_attn_func as flash_attn_3_func |
65 | 68 | from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func |
66 | 69 | else: |
67 | | -    flash_attn_3_func = None
68 | | -    flash_attn_3_varlen_func = None
| 70 | +    try:
| 71 | +        from kernels import get_kernel
| 72 | +
| 73 | +        vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3)
| 74 | +        flash_attn_3_func = vllm_flash_attn3.flash_attn_func
| 75 | +        flash_attn_3_varlen_func = vllm_flash_attn3.flash_attn_varlen_func
| 76 | +        logger.debug(f"Using Flash Attention 3 from {_DEFAULT_HUB_ID_FA3} via the `kernels` library.")
| 77 | +    except ImportError:
| 78 | +        flash_attn_3_func = None
| 79 | +        flash_attn_3_varlen_func = None
69 | 80 |
70 | 81 |
71 | 82 | if _CAN_USE_SAGE_ATTN: |
@@ -132,8 +143,6 @@ def wrap(func): |
132 | 143 | _register_fake = register_fake_no_op |
133 | 144 |
134 | 145 |
135 | | -logger = get_logger(__name__) # pylint: disable=invalid-name |
136 | | - |
137 | 146 | # TODO(aryan): Add support for the following: |
138 | 147 | # - Sage Attention++ |
139 | 148 | # - block sparse, radial and other attention methods |
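For context on the new fallback path: when a local `flash_attn_interface` build is not importable, the module now tries to fetch a prebuilt Flash Attention 3 kernel from the Hub through the `kernels` library, and only falls back to `None` if that fails. A minimal standalone sketch of the same pattern, assuming the `kernels` package is installed and the `kernels-community/vllm-flash-attn3` repo is reachable (the `load_fa3` helper is illustrative, not part of the diffusers API):

```python
# Sketch of the Hub-kernel fallback shown in the diff above.
# Assumes the `kernels` package and Hub access; `load_fa3` is an
# illustrative helper name, not a diffusers API.
from typing import Callable, Optional


def load_fa3(hub_id: str = "kernels-community/vllm-flash-attn3") -> Optional[Callable]:
    """Return a Flash Attention 3 kernel function, or None if unavailable."""
    try:
        from kernels import get_kernel  # raises ImportError if `kernels` is missing

        kernel_module = get_kernel(hub_id)  # downloads and caches the prebuilt kernel
        return kernel_module.flash_attn_func
    except ImportError:
        # Mirror the diff: without the `kernels` package there is no FA3,
        # and callers can dispatch to another attention backend instead.
        return None


if __name__ == "__main__":
    fa3_func = load_fa3()
    print("FA3 available:", fa3_func is not None)
```

Catching only `ImportError` mirrors the change above; a more defensive variant might also handle Hub download failures, since `get_kernel` performs a network fetch on first use.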