From ba794e89e22e4f64b771b8727cb95d19b03bd81a Mon Sep 17 00:00:00 2001
From: Ye Wang
Date: Thu, 1 May 2025 11:15:25 -0500
Subject: [PATCH] [ROCm] support triton-based flash-attn in TE

---
 transformer_engine/pytorch/attention.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index dd9efa2ff..a58984dd8 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -148,7 +148,10 @@ def _get_supported_versions(version_min, version_max):
         )
 else:
     if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
-        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
+        if IS_HIP_EXTENSION and (os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE")=="TRUE"):
+            from flash_attn.flash_attn_triton_amd.interface_fa import varlen_bwd as flash_attn_cuda_bwd
+        else:
+            from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
         from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
         from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
         from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
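
Usage note (not part of the patch): with this change, the ROCm build of Transformer Engine
binds flash_attn_cuda_bwd to the Triton-based varlen_bwd kernel at import time whenever
FLASH_ATTENTION_TRITON_AMD_ENABLE is set to "TRUE". A minimal sketch of opting in follows,
assuming a ROCm build of TE (IS_HIP_EXTENSION is true) and a flash-attn install that ships
flash_attn.flash_attn_triton_amd; the variable has to be set before transformer_engine.pytorch
is imported, since the backend is chosen during module import:

    # Sketch only: opt in to the Triton flash-attn path added by this patch.
    # Assumes a ROCm build of TE and a flash-attn wheel with the Triton AMD interface.
    import os
    os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"  # must be set before the TE import below

    import transformer_engine.pytorch as te  # attention.py now imports varlen_bwd from the Triton interface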