|
26 | 26 | is_flash_attn_3_available, |
27 | 27 | is_flash_attn_available, |
28 | 28 | is_flash_attn_version, |
| 29 | + is_kernels_available, |
29 | 30 | is_sageattention_available, |
30 | 31 | is_sageattention_version, |
31 | 32 | is_torch_npu_available, |
|
54 | 55 | _CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION) |
55 | 56 | _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION) |
56 | 57 |
|
57 | | -_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3" |
58 | 58 |
|
59 | 59 | if _CAN_USE_FLASH_ATTN: |
60 | 60 | from flash_attn import flash_attn_func, flash_attn_varlen_func |
|
67 | 67 | from flash_attn_interface import flash_attn_func as flash_attn_3_func |
68 | 68 | from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func |
69 | 69 | else: |
70 | | - try: |
71 | | - from kernels import get_kernel |
| 70 | + flash_attn_3_func = None |
| 71 | + flash_attn_3_varlen_func = None |
72 | 72 |
|
73 | | - vllm_flash_attn3 = get_kernel(_DEFAULT_HUB_ID_FA3) |
74 | | - flash_attn_3_func = vllm_flash_attn3.flash_attn_func |
75 | | - flash_attn_3_varlen_func = vllm_flash_attn3.flash_attn_varlen_func |
76 | | - logger.debug(f"Using Flash Attention 3 from {_DEFAULT_HUB_ID_FA3} using the `kernels` lib.") |
77 | | - except ImportError: |
78 | | - flash_attn_3_func = None |
79 | | - flash_attn_3_varlen_func = None |
| 73 | +if is_kernels_available(): |
| 74 | + from ..utils.kernels_utils import _get_fa3_from_hub |
| 75 | + |
| 76 | + flash_attn_interface_hub = _get_fa3_from_hub() |
| 77 | + if flash_attn_interface_hub is not None: |
| 78 | + flash_attn_3_hub_func = flash_attn_interface_hub.flash_attn_func |
| 79 | + flash_attn_3_varlen_hub_func = flash_attn_interface_hub.flash_attn_varlen_func |
| 80 | + else: |
| 81 | + flash_attn_3_hub_func = None |
| 82 | + flash_attn_3_varlen_hub_func = None |
| 83 | +else: |
| 84 | + flash_attn_3_hub_func = None |
| 85 | + flash_attn_3_varlen_hub_func = None |
80 | 86 |
|
81 | 87 |
|
82 | 88 | if _CAN_USE_SAGE_ATTN: |
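The inline `get_kernel` call removed above now lives behind `_get_fa3_from_hub` in `..utils.kernels_utils`. A minimal sketch of what that helper could look like, assuming it keeps the previously hard-coded `kernels-community/vllm-flash-attn3` Hub id and simply wraps `kernels.get_kernel` (the actual helper may differ):

```python
# Hypothetical sketch of the helper imported above; the real code lives in diffusers/utils/kernels_utils.py.
from diffusers.utils import is_kernels_available, logging

logger = logging.get_logger(__name__)

_DEFAULT_HUB_ID_FA3 = "kernels-community/vllm-flash-attn3"


def _get_fa3_from_hub():
    if not is_kernels_available():
        return None

    from kernels import get_kernel

    try:
        # Fetches (and caches) a prebuilt FA3 kernel matching the current environment.
        return get_kernel(_DEFAULT_HUB_ID_FA3)
    except Exception as e:
        logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
        raise
```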
@@ -162,6 +168,8 @@ class AttentionBackendName(str, Enum): |
162 | 168 | FLASH_VARLEN = "flash_varlen" |
163 | 169 | _FLASH_3 = "_flash_3" |
164 | 170 | _FLASH_VARLEN_3 = "_flash_varlen_3" |
| 171 | + _FLASH_3_HUB = "_flash_3_hub" |
| 172 | + _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet. |
165 | 173 |
|
166 | 174 | # PyTorch native |
167 | 175 | FLEX = "flex" |
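For reference, the two members added above behave like the existing string-valued entries and can be resolved from their string values (a quick sketch, assuming the module path `diffusers.models.attention_dispatch`):

```python
from diffusers.models.attention_dispatch import AttentionBackendName

backend = AttentionBackendName("_flash_3_hub")  # lookup by string value
assert backend is AttentionBackendName._FLASH_3_HUB

# Declared, but not dispatchable yet (see the "not supported yet" comment above).
print(AttentionBackendName._FLASH_VARLEN_3_HUB.value)  # "_flash_varlen_3_hub"
```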
@@ -355,11 +363,20 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None |
355 | 363 | ) |
356 | 364 |
|
357 | 365 | elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]: |
358 | | - if not _CAN_USE_FLASH_ATTN_3 and (flash_attn_3_func is None and flash_attn_3_varlen_func is None): |
| 366 | + if not _CAN_USE_FLASH_ATTN_3: |
359 | 367 | raise RuntimeError( |
360 | 368 | f"Flash Attention 3 backend '{backend.value}' is not usable because the package is missing or the version is too old. Please build the FA3 beta release from source." |
361 | 369 | ) |
362 | 370 |
|
| 371 | + # TODO: add support for the Hub variant of FA3 varlen later |
| 372 | + elif backend in [AttentionBackendName._FLASH_3_HUB]: |
| 373 | + if not is_kernels_available(): |
| 374 | + raise RuntimeError( |
| 375 | + f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`." |
| 376 | + ) |
| 377 | + elif backend in [AttentionBackendName._FLASH_VARLEN_3_HUB]: |
| 378 | + raise NotImplementedError("Flash Attention 3 varlen from the Hub is not supported yet.") |
| 379 | + |
363 | 380 | elif backend in [ |
364 | 381 | AttentionBackendName.SAGE, |
365 | 382 | AttentionBackendName.SAGE_VARLEN, |
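As a usage sketch of the guard added for the Hub backend (hypothetical, since it imports private helpers from the module): the check passes silently when `kernels` is installed and raises the `RuntimeError` shown above otherwise.

```python
from diffusers.models.attention_dispatch import (
    AttentionBackendName,
    _check_attention_backend_requirements,
)

try:
    _check_attention_backend_requirements(AttentionBackendName._FLASH_3_HUB)
except RuntimeError as err:
    # Raised when the `kernels` package is missing; fix with `pip install kernels`.
    print(err)
```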
@@ -523,6 +540,22 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc |
523 | 540 | return torch.empty_like(query), query.new_empty(lse_shape) |
524 | 541 |
|
525 | 542 |
|
| 543 | +@_custom_op("flash_attn_3_hub_func", mutates_args=(), device_types="cuda") |
| 544 | +def _wrapped_flash_attn_3_hub( |
| 545 | + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| 546 | +) -> Tuple[torch.Tensor, torch.Tensor]: |
| 547 | + out, lse = flash_attn_3_hub_func(query, key, value) |
| 548 | + lse = lse.permute(0, 2, 1) |
| 549 | + return out, lse |
| 550 | + |
| 551 | + |
| 552 | +@_register_fake("flash_attn_3_hub_func") |
| 553 | +def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| 554 | + batch_size, seq_len, num_heads, head_dim = query.shape |
| 555 | + lse_shape = (batch_size, seq_len, num_heads) |
| 556 | + return torch.empty_like(query), query.new_empty(lse_shape) |
| 557 | + |
| 558 | + |
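The `_register_fake` entry above gives the custom op a shape-only (meta) implementation so that tracing, e.g. under `torch.compile`, can pass through the Hub kernel without launching it; it only has to report correct output shapes and dtypes. A small self-contained illustration of those shapes, assuming the `(batch, seq_len, num_heads, head_dim)` layout used throughout this file:

```python
import torch

batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
query = torch.empty(batch_size, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device="meta")

# Mirrors the fake implementation: attention output plus a (batch, seq_len, num_heads) log-sum-exp tensor.
out = torch.empty_like(query)
lse = query.new_empty((batch_size, seq_len, num_heads))
assert out.shape == (2, 128, 8, 64)
assert lse.shape == (2, 128, 8)
```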
526 | 559 | # ===== Attention backends ===== |
527 | 560 |
|
528 | 561 |
|
@@ -645,34 +678,59 @@ def _flash_attention_3( |
645 | 678 | deterministic: bool = False, |
646 | 679 | return_attn_probs: bool = False, |
647 | 680 | ) -> torch.Tensor: |
648 | | - sig = inspect.signature(flash_attn_3_func) |
649 | | - accepted = set(sig.parameters) |
650 | | - params = { |
651 | | - "q": query, |
652 | | - "k": key, |
653 | | - "v": value, |
654 | | - "softmax_scale": scale, |
655 | | - "causal": is_causal, |
656 | | - "qv": None, |
657 | | - "q_descale": None, |
658 | | - "k_descale": None, |
659 | | - "v_descale": None, |
660 | | - "window_size": window_size, |
661 | | - "attention_chunk": 0, |
662 | | - "softcap": softcap, |
663 | | - "num_splits": 1, |
664 | | - "pack_gqa": None, |
665 | | - "deterministic": deterministic, |
666 | | - "sm_margin": 0, |
667 | | - } |
668 | | - kwargs = {} |
669 | | - for name, value in params.items(): |
670 | | - if name not in accepted: |
671 | | - logger.debug(f"{name} is not accepted by the `flash_attn_3_func` method, so it will be discarded.") |
672 | | - else: |
673 | | - kwargs[name] = value |
| 681 | + out, lse, *_ = flash_attn_3_func( |
| 682 | + q=query, |
| 683 | + k=key, |
| 684 | + v=value, |
| 685 | + softmax_scale=scale, |
| 686 | + causal=is_causal, |
| 687 | + qv=None, |
| 688 | + q_descale=None, |
| 689 | + k_descale=None, |
| 690 | + v_descale=None, |
| 691 | + window_size=window_size, |
| 692 | + attention_chunk=0, |
| 693 | + softcap=softcap, |
| 694 | + num_splits=1, |
| 695 | + pack_gqa=None, |
| 696 | + deterministic=deterministic, |
| 697 | + sm_margin=0, |
| 698 | + ) |
| 699 | + return (out, lse) if return_attn_probs else out |
| 700 | + |
674 | 701 |
|
675 | | - out, lse, *_ = flash_attn_3_func(**kwargs) |
| 702 | +@_AttentionBackendRegistry.register( |
| 703 | + AttentionBackendName._FLASH_3_HUB, |
| 704 | + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], |
| 705 | +) |
| 706 | +def _flash_attention_3_hub( |
| 707 | + query: torch.Tensor, |
| 708 | + key: torch.Tensor, |
| 709 | + value: torch.Tensor, |
| 710 | + scale: Optional[float] = None, |
| 711 | + is_causal: bool = False, |
| 712 | + window_size: Tuple[int, int] = (-1, -1), |
| 713 | + softcap: float = 0.0, |
| 714 | + deterministic: bool = False, |
| 715 | + return_attn_probs: bool = False, |
| 716 | +) -> torch.Tensor: |
| 717 | + out, lse, *_ = flash_attn_3_hub_func( |
| 718 | + q=query, |
| 719 | + k=key, |
| 720 | + v=value, |
| 721 | + softmax_scale=scale, |
| 722 | + causal=is_causal, |
| 723 | + qv=None, |
| 724 | + q_descale=None, |
| 725 | + k_descale=None, |
| 726 | + v_descale=None, |
| 727 | + window_size=window_size, |
| 728 | + softcap=softcap, |
| 729 | + num_splits=1, |
| 730 | + pack_gqa=None, |
| 731 | + deterministic=deterministic, |
| 732 | + sm_margin=0, |
| 733 | + ) |
676 | 734 | return (out, lse) if return_attn_probs else out |
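Finally, a rough usage sketch of the new Hub-backed path, assuming a CUDA machine with `kernels` installed so that `flash_attn_3_hub_func` resolved at import time. In practice the function is reached through the backend registry (by selecting the `_flash_3_hub` backend) rather than called directly:

```python
import torch
from diffusers.models.attention_dispatch import _flash_attention_3_hub

batch, seq_len, heads, head_dim = 2, 1024, 8, 128
q = torch.randn(batch, seq_len, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = _flash_attention_3_hub(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 1024, 8, 128])
```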
677 | 735 |
|
678 | 736 |
|
|