File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed
transformer_engine/pytorch Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -148,7 +148,10 @@ def _get_supported_versions(version_min, version_max):
148148 )
149149else :
150150 if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version :
151- from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
151+ if IS_HIP_EXTENSION and (os .getenv ("FLASH_ATTENTION_TRITON_AMD_ENABLE" , "FALSE" )== "TRUE" ):
152+ from flash_attn .flash_attn_triton_amd .interface_fa import varlen_bwd as flash_attn_cuda_bwd
153+ else :
154+ from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
152155 from flash_attn .flash_attn_interface import flash_attn_func , flash_attn_varlen_func
153156 from flash_attn .flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
154157 from flash_attn .flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
You can’t perform that action at this time.
0 commit comments