diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index dd9efa2ff..a58984dd8 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -148,7 +148,10 @@ def _get_supported_versions(version_min, version_max):
     )
 else:
     if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
-        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
+        if IS_HIP_EXTENSION and (os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"):
+            from flash_attn.flash_attn_triton_amd.interface_fa import varlen_bwd as flash_attn_cuda_bwd
+        else:
+            from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
         from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
         from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
         from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
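
For context, a minimal standalone sketch of the backend selection this hunk introduces: on a ROCm build with `FLASH_ATTENTION_TRITON_AMD_ENABLE=TRUE`, `varlen_bwd` is taken from flash-attn's Triton AMD interface instead of the `flash_attn_2_cuda` extension. In `attention.py` the `IS_HIP_EXTENSION` flag and the flash-attn version bounds are already in scope; here `IS_HIP_EXTENSION` is assumed to come from `torch.utils.cpp_extension`, which is an assumption of this sketch, not part of the patch.

```python
# Hedged sketch (not part of the patch): the env-var gated import fallback
# shown standalone. Assumes flash-attn 2.x is installed with the relevant backend.
import os

try:
    # In attention.py this flag is already defined; importing it from
    # torch.utils.cpp_extension here is an assumption for self-containment.
    from torch.utils.cpp_extension import IS_HIP_EXTENSION
except ImportError:
    IS_HIP_EXTENSION = False

use_triton_amd = (
    IS_HIP_EXTENSION
    and os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
)

if use_triton_amd:
    # ROCm build with the Triton AMD backend enabled: take varlen_bwd from the
    # Triton interface rather than the CUDA extension module.
    from flash_attn.flash_attn_triton_amd.interface_fa import varlen_bwd as flash_attn_cuda_bwd
else:
    # Default path (CUDA, or ROCm without the Triton backend enabled).
    from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
```

The alias is kept as `flash_attn_cuda_bwd` in both branches so downstream callers in `attention.py` are unaffected by which backend supplies the varlen backward kernel.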