-
Notifications
You must be signed in to change notification settings - Fork 23
[ROCm] support triton-based flash-attn in TE #177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Conversation
| else: | ||
| if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version: | ||
| from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd | ||
| if IS_HIP_EXTENSION and (os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE")=="TRUE"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like flash_attn_cuda_bwd is only used on NV platform and only when explicitly requested by user. So this changes is not needed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's used in our version as well:
| flash_attn_cuda_bwd( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It exists in our version but not used. All its usages are under use_FAv2_bwd condition
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the use_FAv2_bwd is a parameter in the forward function of class FusedAttnFunc_qkvpacked, FusedAttnFunc_kvpacked and FusedAttnFunc, which user could set to true I guess? See
| use_FAv2_bwd, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In normal code-path use_FAv2_bwd is got from NVTE_FUSED_ATTN_USE_FAv2_BWD at
| self.use_FAv2_bwd = os.getenv( |
and it is always set to False for AMD devices there and later at
| use_FAv2_bwd = ( |
Indeed, users may call FusedAttnFunc* classes directly which will bypass a lot of other logic and checks so I wouldn't worry about this parameter specifically. For sanity, this feature may be explicitly disabled for ROCm but anyway keeping things as-is seems less evil than rely on independent project env variables.
|
@wangye805 Could you rebase? |
Description
Enable triton-based flash-attn in TE
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: