AITER Native Padding Support and BSHD + Padding --> THD + Padding conversion #354
Open
Micky774 wants to merge 34 commits into dev from zain/aiter-native-bshd-thd
+412 −196
Commits
e90b991 [ROCm] manually pick up fwd native padding support from Meekail's PR (wangye805)
9d02d52 Initial update (Micky774)
81bac35 Updated stride (Micky774)
54ee86a Corrected typing in allocation portions (Micky774)
47a7cab Applied Ye's patch (Micky774)
0e0064f [ROCm] manually pick Meekail's PR to support native padding for bwd (wangye805)
945ab5b [ROCm] jax use runtime segment (wangye805)
579b592 [ROCm] get runtime max_seqlen as well (wangye805)
73247d9 [ROCm] support v2 bwd native padding (wangye805)
7e1c3ef Updated conversion to include bwd pass (Micky774)
51090d3 Merge branch 'yewang12/te_aiter_native_padding_bwd' into zain/aiter-b… (Micky774)
0e121ba Added BWD BSHD-->THD conversion and minor logic refactor (Micky774)
734692d Corrected softmax lse bug (Micky774)
5c24188 Updated logic flow and re-caclulation (Micky774)
b59d466 [ROCm] manually pick Meekail's PR to support native padding for bwd (wangye805)
97073fe Merge branch 'zain/aiter-bwd-bshd-thd' into zain/aiter-native-bshd-thd (Micky774)
f27a99f Added env var guard (Micky774)
d757aef Merge branch 'dev' into zain/aiter-native-bshd-thd (Micky774)
33c5912 Updated ptr variables and streamlined dispatch (Micky774)
af57290 Added env guard (Micky774)
bc8f4a7 Corrected bshd_to_thd conversion arguments (Micky774)
b7f2cf8 Corrected logical flow (Micky774)
3e48a02 Guarded memset and corrected allocation (Micky774)
b1094c6 Remove V3 API check and guard memsets (Micky774)
c3a0fce PR comments (Micky774)
9ab8df4 Updated documentation (Micky774)
2adfb6e PR review reconciliation (Micky774)
bb3868d Added explicit test (Micky774)
52c8167 Merge branch 'dev' into zain/aiter-native-bshd-thd (Micky774)
6206d58 Formatting for bwd debug (Micky774)
0582851 Resolved error when using mixed formats e.g. sbhd_2bshd (Micky774)
78716de Updated guard on flash-attention forced support (Micky774)
85bb6f6 Added check for SBHD_2BSHD (Micky774)
a12105d Added guard on dk/dv memset (Micky774)
Submodule aiter updated 28 files
Do we have GQA/MQA + MLA test cases with and without padding? If not, can we create them to verify that this flow actually works?
I'll try to add one on the JAX side. For now I've added one on the TE side, but it can't run yet because too few backends support that configuration; that may change, e.g. as we update AOTriton.
Thanks. Then let's skip the PyTorch-side GQA/MQA + MLA test for now. You can put a TODO here and add it later, once other backends support it.
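For readers following the thread, here is a rough sketch of what a padded BSHD to THD conversion does conceptually, and of the kind of reference a padding test could compare against. It is not the kernel from this PR. The names `bshd_to_thd` and `thd_to_bshd` below are hypothetical pure-Python helpers, each token stands in for a whole (heads, dim) slice, and padding is assumed to sit at the end of each sequence.

```python
from itertools import accumulate


def bshd_to_thd(x, seqlens):
    """Pack a padded [batch][seq] nested list into a flat [total_tokens] list.

    x       : list of sequences, all padded to the same length
    seqlens : number of valid (non-padding) tokens in each sequence
    Returns (packed, cu_seqlens); cu_seqlens[i] is the offset of sequence i
    in the packed list, and cu_seqlens[-1] is the total token count.
    """
    cu_seqlens = [0] + list(accumulate(seqlens))
    # Keep only the first n tokens of each sequence, dropping trailing padding.
    packed = [tok for seq, n in zip(x, seqlens) for tok in seq[:n]]
    return packed, cu_seqlens


def thd_to_bshd(packed, cu_seqlens, max_seqlen, pad):
    """Inverse of bshd_to_thd: scatter packed tokens back and re-pad."""
    out = []
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        seq = packed[start:end]
        out.append(seq + [pad] * (max_seqlen - len(seq)))
    return out


# Two sequences padded to length 4, with 3 and 1 valid tokens respectively.
x = [["a0", "a1", "a2", "P"], ["b0", "P", "P", "P"]]
packed, cu_seqlens = bshd_to_thd(x, [3, 1])
restored = thd_to_bshd(packed, cu_seqlens, max_seqlen=4, pad="P")
```

A round trip like this (pack, then unpack and re-pad) is one simple property a with/without-padding test could check before comparing attention outputs across backends.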