Conversation

@wangye805
Collaborator

Description

Enable triton-based flash-attn in TE

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@wangye805 marked this pull request as ready for review on May 1, 2025, 16:34
else:
    if _flash_attn_version_required <= _flash_attn_version <= _flash_attn_max_version:
        from flash_attn_2_cuda import varlen_bwd as flash_attn_cuda_bwd
        if IS_HIP_EXTENSION and (os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"):
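The gating condition in the diff above can be sketched as a standalone predicate. This is a minimal illustration, not the actual TE code; only the env variable name `FLASH_ATTENTION_TRITON_AMD_ENABLE` and its `"TRUE"` comparison come from the diff, while the function name is hypothetical:

```python
import os

def triton_backend_requested() -> bool:
    """Mirror the diff's condition: the Triton AMD path is taken only
    when the env variable is explicitly set to the string "TRUE"."""
    return os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"

# Unset (or any value other than "TRUE") keeps the default CUDA import path.
os.environ.pop("FLASH_ATTENTION_TRITON_AMD_ENABLE", None)
print(triton_backend_requested())  # False

os.environ["FLASH_ATTENTION_TRITON_AMD_ENABLE"] = "TRUE"
print(triton_backend_requested())  # True
```

Note the string comparison is exact, so values like "true" or "1" would not enable the Triton path under this condition.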
Collaborator

It looks like flash_attn_cuda_bwd is only used on the NV platform, and only when explicitly requested by the user, so this change is not needed.

Collaborator Author

It's used in our version as well:

Collaborator

It exists in our version but is not used: all its usages are under the use_FAv2_bwd condition.

Collaborator Author

use_FAv2_bwd is a parameter of the forward function in the FusedAttnFunc_qkvpacked, FusedAttnFunc_kvpacked, and FusedAttnFunc classes, which a user could set to true, I guess? See

as an example. Then this flash_attn_cuda_bwd could be used elsewhere by customers.

Collaborator

In the normal code path, use_FAv2_bwd is read from NVTE_FUSED_ATTN_USE_FAv2_BWD at

self.use_FAv2_bwd = os.getenv(

and it is always set to False for AMD devices there and later at

Indeed, users may call the FusedAttnFunc* classes directly, which bypasses a lot of other logic and checks, so I wouldn't worry about this parameter specifically. For sanity, this feature could be explicitly disabled for ROCm, but keeping things as-is seems less evil than relying on an independent project's env variables.
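The behavior described above can be sketched as follows. This is a hedged illustration of the comment's description, not the TE implementation: the env variable name NVTE_FUSED_ATTN_USE_FAv2_BWD comes from the thread, but the function name, the default value "0", and the is_hip parameter are assumptions for the example:

```python
import os

def read_use_fav2_bwd(is_hip: bool) -> bool:
    """Sketch of the normal code path: the flag is read from the
    NVTE_FUSED_ATTN_USE_FAv2_BWD env variable (default "0" assumed)
    and, per the comment above, forced to False on AMD devices."""
    flag = os.getenv("NVTE_FUSED_ATTN_USE_FAv2_BWD", "0") == "1"
    if is_hip:
        flag = False  # always set to False for AMD devices
    return flag

os.environ["NVTE_FUSED_ATTN_USE_FAv2_BWD"] = "1"
print(read_use_fav2_bwd(is_hip=False))  # True
print(read_use_fav2_bwd(is_hip=True))   # False
```

Calling FusedAttnFunc* directly bypasses this read, which is why the parameter can still reach the backward path on ROCm.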

@wenchenvincent
Collaborator

@wangye805 Could you rebase?
