-
Notifications
You must be signed in to change notification settings - Fork 22
AITER Native Padding Support and BSHD + Padding --> THD + Padding conversion #354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ROCm] support v2 bwd native padding
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Outdated
Show resolved
Hide resolved
wangye805
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generally, I think we can try to remove all memset except for dq, dq_acc. We can confirm with aiter/ck people
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Outdated
Show resolved
Hide resolved
transformer_engine/common/ck_fused_attn/src/ck_fused_attn_fwd.cpp
Outdated
Show resolved
Hide resolved
Those failures were due to a mix of not correctly dispatching to the |
wangye805
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For those newly added hybrid qkv formats in upstream (NVTE_SBHD_2BSHD, NVTE_BSHD_2SBHD, NVTE_THD_2BSHD, and NVTE_THD_2SBHD): in addition to the SBHD_2BSHD pytest failures, are we able to correctly handle all other 3? Or is there only SBHD_2BSHD pytests now?
NV upstream is separating format and is_ragged on q/kv and do subsequent processings accordingly:
TransformerEngine/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
Lines 79 to 82 in 32e2d1d
| NVTE_QKV_Format q_format = nvte_get_q_format(layout); | |
| NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); | |
| bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD); | |
| bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD); |
Maybe we can try similar technique. If I recall correctly, we need padding/unpadding for just q in SBHD_2BSHD and for just k/v in BSHD_2SBHD.
Or it's okay if you want to leave this for another PR.
By the way, there is an "extra line" comment you may have ignored :-)
|
In fact, I saw some level 3 pytorch cp pytest failures by run level 3 ci locally: Attached you can find the detailed log |
658c105 to
871cb4e
Compare
|
Manual CI runs in local MI300X machine: |
Description
Feature update PR which includes several iterative changes for client-driven optimization targets. This PR includes both API changes for CK/AITER as well as changes in internal integration. See the list of changes for specifics.
Note that this will not be ready for merger until ROCm/aiter#1212 is merged in and this PR's AITER commit is updated.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
max_seqlencalculation gated by new env varNVTE_CK_RUNTIME_MAX_SEQLENv3_api_checksupport (temporary)pad_between_seqs(need to follow-up with a PR cleaning up test suite for oldpad_between_seqsedge-cases)NVTE_CK_RUNTIME_NUM_SEGMENTSto guard runtime-calculation of the number of segments in the JAX integrationChecklist: