Commit 41980ce

[bugfix] fix flash-attention2 unavailable error for Ascend NPU (#40151)
* [bugfix] fix flash-attention2 unavailable error for Ascend NPU
* remove redundant apply_rotary_emb usage
* fix ruff check error
* pad_input and unpad_input use same implementation as fa2
* rollback redundant codes
* fix ruff check error
* optimize fa2 judgement logic
1 parent eba1d62 commit 41980ce
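
The failure mode this commit addresses: explicitly requesting FlashAttention-2 on an Ascend NPU, where the `flash-attn` package cannot be installed, made `_lazy_imports` attempt `from flash_attn import ...` and raise an ImportError. A rough reproduction sketch of that request follows; the checkpoint name is a placeholder, not part of this commit.

import torch
from transformers import AutoModelForCausalLM

# Requesting FlashAttention-2 explicitly routes attention setup through
# _lazy_imports. Before this commit the FA2 branch was taken even though
# flash-attn is not installable on NPU; after it, the NPU branch supplies
# npu_flash_attn_func / npu_flash_attn_varlen_func instead.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # placeholder checkpoint
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)
model.to("npu")  # requires torch_npu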

File tree: 3 files changed, +10, -30 lines

  src/transformers/integrations/npu_flash_attention.py
  src/transformers/modeling_flash_attention_utils.py
  src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

src/transformers/integrations/npu_flash_attention.py

Lines changed: 1 addition & 17 deletions
@@ -19,7 +19,7 @@


 if is_torch_npu_available():
-    from torch_npu import npu_fusion_attention, npu_rotary_mul
+    from torch_npu import npu_fusion_attention


 # FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
@@ -136,19 +136,3 @@ def npu_flash_attn_varlen_func(
     )[0]

     return output
-
-
-def npu_apply_rotary_emb(x, cos, sin, **kwargs):
-    # cos tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
-    if len(cos.shape) == 2 and cos.shape[-1] == x.shape[-1] // 2:
-        cos = cos.repeat(1, 2)
-        # cos tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
-        cos = cos.unsqueeze(0).unsqueeze(2)
-
-    # sin tensor after chunk should be repeated through chunked dimension to original shape on Ascend NPU
-    if len(sin.shape) == 2 and sin.shape[-1] == x.shape[-1] // 2:
-        sin = sin.repeat(1, 2)
-        # sin tensor with [S,D] shape should be unsqueezed to 4-d tensor with shape [1,S,1,D]
-        sin = sin.unsqueeze(0).unsqueeze(2)
-
-    return npu_rotary_mul(x, cos, sin)
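
The removed `npu_apply_rotary_emb` only reshaped `cos`/`sin` so they could be fed to `torch_npu.npu_rotary_mul`; with the redundant usage dropped from the modeling code, rotary embeddings go through the generic path. For orientation, here is a minimal rotate-half sketch of what a generic rotary application does; `apply_rotary_generic` is an illustrative name, not a transformers helper.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_generic(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x has shape [batch, seq, heads, head_dim]; cos/sin have shape [seq, head_dim]
    # and are broadcast to [1, seq, 1, head_dim] before the elementwise rotation.
    cos = cos.unsqueeze(0).unsqueeze(2)
    sin = sin.unsqueeze(0).unsqueeze(2)
    return (x * cos) + (rotate_half(x) * sin)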

src/transformers/modeling_flash_attention_utils.py

Lines changed: 9 additions & 5 deletions
@@ -77,16 +77,20 @@ def _lazy_imports(implementation: Optional[str]):
     """
     is_fa2 = is_flash_attn_2_available()
     is_fa3 = is_flash_attn_3_available()
-    if implementation == "flash_attention_2" or (implementation is None and is_fa2 and not is_fa3):
+
+    pad_input, unpad_input = _pad_input, _unpad_input
+
+    if (implementation == "flash_attention_2" and is_fa2) or (implementation is None and is_fa2 and not is_fa3):
         from flash_attn import flash_attn_func, flash_attn_varlen_func
         from flash_attn.bert_padding import pad_input, unpad_input
+    elif is_torch_npu_available():
+        # Package `flash-attn` is unavailable on Ascend NPU, which will cause ImportError
+        # Flash-Attention2 related apis for Ascend NPU must be imported from `.integrations.npu_flash_attention` module
+        from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
+        from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
     else:
-        pad_input, unpad_input = _pad_input, _unpad_input
         if implementation == "flash_attention_3" or (implementation is None and is_fa3):
             from flash_attn_interface import flash_attn_func, flash_attn_varlen_func
-        elif is_torch_npu_available():
-            from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_func
-            from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_varlen_func
         # Kernels fallback
         else:
             flash_attn_func = getattr(implementation, "flash_attn_func", None)
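
To make the revised judgement logic easier to trace, the sketch below spells out the selection order it implements: the CUDA `flash-attn` package is used only when the request can actually be satisfied by it, an Ascend NPU falls back to the integration kernels, and FlashAttention-3 plus the kernels fallback keep their places. The boolean parameters stand in for transformers' availability checks; `pick_backend` itself is illustrative, not library API.

from typing import Optional

def pick_backend(
    implementation: Optional[str],
    fa2_installed: bool,   # stands in for is_flash_attn_2_available()
    fa3_installed: bool,   # stands in for is_flash_attn_3_available()
    on_ascend_npu: bool,   # stands in for is_torch_npu_available()
) -> str:
    # Mirrors the post-fix branch order of _lazy_imports.
    if (implementation == "flash_attention_2" and fa2_installed) or (
        implementation is None and fa2_installed and not fa3_installed
    ):
        return "flash_attn (CUDA package)"
    if on_ascend_npu:
        # flash-attn cannot be installed on Ascend NPU, so an explicit
        # flash_attention_2 request now lands here instead of raising ImportError.
        return "integrations.npu_flash_attention"
    if implementation == "flash_attention_3" or (implementation is None and fa3_installed):
        return "flash_attn_interface (FA3)"
    return "kernels fallback"

# A user asks for flash_attention_2 on an NPU without the flash-attn package:
print(pick_backend("flash_attention_2", fa2_installed=False, fa3_installed=False, on_ascend_npu=True))
# -> integrations.npu_flash_attention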

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 0 additions & 8 deletions
@@ -44,7 +44,6 @@
 from ...cache_utils import Cache
 from ...configuration_utils import PretrainedConfig, layer_type_validation
 from ...generation import GenerationMixin
-from ...modeling_flash_attention_utils import is_flash_attn_available
 from ...modeling_outputs import BaseModelOutput, ModelOutput
 from ...modeling_rope_utils import rope_config_validation
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
@@ -58,13 +57,6 @@
 from ...utils.hub import cached_file


-if is_flash_attn_available():
-    from ...modeling_flash_attention_utils import apply_rotary_emb, flash_attn_varlen_func
-else:
-    flash_attn_varlen_func = None
-    apply_rotary_emb = None
-
-
 logger = logging.get_logger(__name__)
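
With the conditional imports gone, the modular Qwen2.5-Omni file no longer needs `is_flash_attn_available()` at import time; attention dispatch is left to the shared utilities and the `ALL_ATTENTION_FUNCTIONS` registry it already imports. A rough sketch of that runtime-dispatch pattern; `resolve_attention` is an illustrative helper, not the Qwen2.5-Omni implementation.

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def resolve_attention(config, eager_fn):
    # Pick the attention callable when the layer runs, so importing the model
    # module never requires flash-attn (or its NPU substitute) to be present.
    if config._attn_implementation == "eager":
        return eager_fn
    return ALL_ATTENTION_FUNCTIONS[config._attn_implementation]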