Commit ba794e8

[ROCm] support triton-based flash-attn in TE

1 parent f602b36

1 file changed, 4 additions (+), 1 deletion (-)

transformer_engine/pytorch/attention.py
@@ -148,7 +148,10 @@ def _get_supported_versions(version_min, version_max):
     )
 else:
     if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
-        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
+        if IS_HIP_EXTENSION and (os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE")=="TRUE"):
+            from flash_attn.flash_attn_triton_amd.interface_fa import varlen_bwd as flash_attn_cuda_bwd
+        else:
+            from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
         from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
         from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
         from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
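A minimal usage sketch of the new path, assuming a ROCm build of Transformer Engine (where IS_HIP_EXTENSION is True) and a flash-attn install that ships the Triton AMD backend: FLASH_ATTENTION_TRITON_AMD_ENABLE is read once at import time in attention.py, so it has to be set before transformer_engine.pytorch is imported.

import os

# Opt into the Triton-based flash-attn kernels on ROCm. This must happen before
# Transformer Engine is imported, because attention.py checks the variable at
# import time to decide where varlen_bwd is imported from.
os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"

import transformer_engine.pytorch as te  # backend selection happens during this import

On CUDA builds, or when the variable is unset or not equal to "TRUE", the import falls back to flash_attn_2_cuda exactly as before this commit.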
