Commit ac0b468

[bugfix] fix flash_attention_2 unavailable error on Ascend NPU (#39844)
1 parent cf243a1
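
For context, the error this commit fixes shows up when FlashAttention2 is explicitly requested on an Ascend NPU host, where the flash-attn package cannot be installed. A minimal, hedged reproduction sketch follows; the model id and the NPU environment are illustrative assumptions, not taken from the commit:

    # Minimal reproduction sketch; model id and environment are illustrative.
    import torch
    import torch_npu  # registers the "npu" device with PyTorch
    from transformers import AutoModelForCausalLM

    # Before this fix, the flash_attention_2 availability checks could still fail
    # on Ascend NPU because the flash-attn package is not installable there.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-7B-Instruct",
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
    ).to("npu")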

File tree

3 files changed: 18 additions & 4 deletions

src/transformers/integrations/npu_flash_attention.py
src/transformers/modeling_flash_attention_utils.py
src/transformers/modeling_utils.py

src/transformers/integrations/npu_flash_attention.py

Lines changed: 5 additions & 0 deletions
@@ -267,3 +267,8 @@ def npu_apply_rotary_emb(x, cos, sin, **kwargs):
     sin = sin.unsqueeze(0).unsqueeze(2)

     return npu_rotary_mul(x, cos, sin)
+
+
+def get_npu_flash_attn_funcs():
+    # return flash attention related functions used for Ascend NPU in order
+    return npu_flash_attn_func, npu_flash_attn_varlen_func, pad_input, unpad_input, False
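
The new helper returns the same shape of tuple that the other flash-attention backends expose, so a caller can unpack it directly. A hedged usage sketch; interpreting the trailing False as an "is flash-attention-3" flag is an assumption, not something stated in the diff:

    # Caller-side unpacking sketch; the meaning of the final element is assumed.
    from transformers.integrations.npu_flash_attention import get_npu_flash_attn_funcs

    attn_fn, attn_varlen_fn, pad_input, unpad_input, is_fa3 = get_npu_flash_attn_funcs()
    assert is_fa3 is False  # assumed: the NPU path advertises FA2-style, not FA3, kernels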

src/transformers/modeling_flash_attention_utils.py

Lines changed: 7 additions & 2 deletions
@@ -261,7 +261,7 @@ def fa_peft_integration_check(q, k, v, target_dtype: Optional[torch.dtype] = Non

 def _lazy_imports(impl: Optional[str]):
     # returns funcs and pad/unpad based on impl
-    is_fa2 = is_flash_attn_2_available() or is_torch_npu_available()
+    is_fa2 = is_flash_attn_2_available()
     is_fa3 = is_flash_attn_3_available()
     if impl == "flash_attention_2" or (impl is None and is_fa2 and not is_fa3):
         try:
@@ -299,7 +299,12 @@ def _lazy_imports(impl: Optional[str]):
             raise ImportError(
                 "Failed to import flash attention 2, please install it or use another implementation."
             ) from e
-    if impl == "flash_attention_3" or (impl is None and is_fa3):
+    elif is_torch_npu_available():
+        # get flash attention related functions from `.integrations.npu_flash_attention` module for Ascend NPU
+        from .integrations.npu_flash_attention import get_npu_flash_attn_funcs
+
+        return get_npu_flash_attn_funcs()
+    elif impl == "flash_attention_3" or (impl is None and is_fa3):
         from flash_attn_interface import flash_attn_func, flash_attn_varlen_func

         pad_input, unpad_input = _fa3_pad_input, _fa3_unpad_input
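
The new elif branch is only reachable when `is_torch_npu_available()` reports True, so it can help to confirm the runtime actually sees an NPU before expecting this fallback. A small hedged check; the environment assumptions are mine:

    # Sanity check that the NPU fallback branch is reachable in the current environment.
    from transformers.utils import is_torch_npu_available

    if is_torch_npu_available():
        print("NPU detected: the new elif branch in _lazy_imports can take effect.")
    else:
        print("No NPU detected: the standard flash_attn / flash_attn_interface paths apply.")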

src/transformers/modeling_utils.py

Lines changed: 6 additions & 2 deletions
@@ -2483,8 +2483,12 @@ def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
         preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
         install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."

-        # package `flash-attn` can not be installed on Ascend NPU, ignore related validation logi
-        if importlib.util.find_spec("flash_attn") is None and not is_torch_npu_available():
+        # package `flash-attn` cannot be installed on Ascend NPU, so the following validation logic can be skipped.
+        if is_torch_npu_available():
+            logger.info("Detect using FlashAttention2 on Ascend NPU.")
+            return True
+
+        if importlib.util.find_spec("flash_attn") is None:
             raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
         else:
            # Check FA2 installed version compatibility
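
The net effect is that the dispatch check short-circuits on NPU before probing for the flash_attn package at all. A hedged paraphrase of the new check order; the helper name below is hypothetical and not part of the transformers API:

    # Hypothetical standalone paraphrase of the check order after this change.
    import importlib.util

    from transformers.utils import is_torch_npu_available

    def can_dispatch_fa2_sketch() -> bool:
        # Ascend NPU: flash-attn wheels cannot be installed, so skip the package probe.
        if is_torch_npu_available():
            return True
        # Elsewhere: FlashAttention2 still requires the flash_attn package to be importable.
        return importlib.util.find_spec("flash_attn") is not None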
