5 changes: 4 additions & 1 deletion `transformer_engine/pytorch/attention.py`

```diff
@@ -148,7 +148,10 @@ def _get_supported_versions(version_min, version_max):
     )
 else:
     if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
-        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
+        if IS_HIP_EXTENSION and (os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE")=="TRUE"):
```
**Collaborator:**

It looks like `flash_attn_cuda_bwd` is only used on the NV platform, and only when explicitly requested by the user, so this change is not needed.

**Collaborator Author:**

It's used in our version as well:

**Collaborator:**

It exists in our version, but it is not used: all of its usages are guarded by the `use_FAv2_bwd` condition.

**Collaborator Author:**

`use_FAv2_bwd` is a parameter of the `forward` function of the classes `FusedAttnFunc_qkvpacked`, `FusedAttnFunc_kvpacked`, and `FusedAttnFunc`, which a user could set to `True`, I guess? See

as an example. This `flash_attn_cuda_bwd` could then be used somewhere else by customers.
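For readers outside the thread, here is a minimal, hypothetical sketch of the pattern being discussed; it is not Transformer Engine's actual `FusedAttnFunc` signature, only an illustration of how a `use_FAv2_bwd` flag passed to `forward` can steer `backward` toward `flash_attn_cuda_bwd`:

```python
import torch

class ToyFusedAttnFunc(torch.autograd.Function):
    """Hypothetical reduction of FusedAttnFunc: forward records the flag,
    backward branches on it. The real class takes many more arguments."""

    @staticmethod
    def forward(ctx, q, k, v, use_FAv2_bwd=False):
        ctx.use_FAv2_bwd = use_FAv2_bwd
        ctx.save_for_backward(q, k, v)
        # Placeholder "attention" so the sketch runs end to end.
        return q + k + v

    @staticmethod
    def backward(ctx, grad_out):
        q, k, v = ctx.saved_tensors
        if ctx.use_FAv2_bwd:
            # This is the branch where the real code would call
            # flash_attn_cuda_bwd (the imported varlen_bwd), which is
            # why the import touched by this diff still matters.
            pass
        # Gradients for q, k, v; None for the non-tensor flag.
        return grad_out, grad_out, grad_out, None

# A user calling the autograd Function directly can flip the flag:
q = k = v = torch.ones(2, requires_grad=True)
ToyFusedAttnFunc.apply(q, k, v, True).sum().backward()
```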

**Collaborator:**

In the normal code path, `use_FAv2_bwd` is read from `NVTE_FUSED_ATTN_USE_FAv2_BWD` at

`self.use_FAv2_bwd = os.getenv(`

and it is always set to `False` for AMD devices there, and later at

Indeed, users may call the `FusedAttnFunc*` classes directly, which bypasses a lot of other logic and checks, so I wouldn't worry about this parameter specifically. For sanity, this feature could be explicitly disabled for ROCm, but keeping things as-is seems less evil than relying on an independent project's environment variables.
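The truncated snippet above presumably expands to an environment-variable check along these lines (a sketch under assumptions; the exact default value and any additional device conditions in the real `attention.py` may differ):

```python
import os

def _read_use_FAv2_bwd(is_hip_extension: bool) -> bool:
    """Sketch of the normal code path: the env var opts in, and AMD
    devices are forced to False. Names and defaults are assumptions."""
    opted_in = os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
    return opted_in and not is_hip_extension
```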

```diff
+            from flash_attn.flash_attn_triton_amd.interface_fa import varlen_bwd as flash_attn_cuda_bwd
+        else:
+            from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
     from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
     from flash_attn.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd
     from flash_attn.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd
```
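To summarize the diff's effect, here is a standalone sketch of the import-selection logic it introduces. `IS_HIP_EXTENSION` is hard-coded below as a stand-in for Transformer Engine's build flag, and the imports only resolve against a matching flash-attn installation:

```python
import os

# Stand-in for Transformer Engine's build-time ROCm flag; hard-coded
# here purely for illustration.
IS_HIP_EXTENSION = True

def select_varlen_bwd():
    """Mirror the diff: on ROCm builds, prefer flash-attn's Triton AMD
    backward kernel when the user opts in via the env var; otherwise
    fall back to the CUDA extension shipped with flash-attn v2."""
    if IS_HIP_EXTENSION and os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE":
        from flash_attn.flash_attn_triton_amd.interface_fa import varlen_bwd
    else:
        from flash_attn_2_cuda import varlen_bwd
    return varlen_bwd
```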