Update build system for AOTriton 0.11b and upgrade FWD call to V3 API
#360
base: dev
Conversation
```cmake
string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
set(__AOTRITON_VER "0.11b")
set(__AOTRITON_SHA256
  "a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0
```
This can be removed.
TE should never download the runtime; it must build the runtime from source with a custom suffix, due to a potential conflict with the libaotriton shipped by PyTorch.
```cpp
bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int window_size_left, int window_size_right, NVTE_QKV_Layout layout,
```
Consider using std::optional
All integer values are valid inputs for AOTriton's SWA (hence I sometimes refer to it as "generic SWA").
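As a rough sketch of what that suggestion could look like (the function and parameter names here are illustrative, not the actual TE signature), the window bounds would be passed as optionals, with an empty value meaning "unbounded on this side" and any integer otherwise accepted as a generic-SWA bound:

```cpp
#include <optional>

// Hypothetical sketch only: std::nullopt means "no bound on this side";
// any integer value is otherwise a valid generic-SWA window bound.
void fused_attn_fwd_sketch(bool is_training, float scaling_factor,
                           float dropout_probability,
                           std::optional<int> window_size_left,
                           std::optional<int> window_size_right);
```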
```cpp
void fused_attn_aotriton_fwd_qkvpacked(
    size_t b, size_t h, size_t max_seqlen, size_t d,
    bool is_training, float attn_scale, float dropout,
    uint64_t window_left, uint64_t window_right,
```
uint64_t?
I think it should be int64_t, as the default (non-SWA) window sizes from NV upstream are -1.
I've set it to int32_t since that's ultimately what AOTriton uses.
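A minimal illustration of the trade-off being discussed (the function name here is made up for the example): the NV-upstream convention uses -1 as the default, so the public parameters must be signed, and they are narrowed to the 32-bit type AOTriton consumes at the call site:

```cpp
#include <cstdint>

// Sketch only: signed parameters so the upstream default of -1 round-trips,
// narrowed to int32_t because that is what AOTriton ultimately takes.
void fwd_window_args_sketch(int64_t window_left, int64_t window_right) {
  int32_t aot_left = static_cast<int32_t>(window_left);    // assumed to fit
  int32_t aot_right = static_cast<int32_t>(window_right);  // assumed to fit
  (void)aot_left;
  (void)aot_right;
}
```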
```cpp
void fused_attn_aotriton_fwd(
    size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
    bool is_training, float attn_scale, float dropout,
    uint64_t window_left, uint64_t window_right,
```
uint64?
```cpp
void fused_attn_aotriton_fwd_qkvpacked(
    size_t b, size_t h, size_t max_seqlen, size_t d,
    bool is_training, float attn_scale, float dropout,
    uint64_t window_left, uint64_t window_right,
```
As I mentioned above, use std::optional<int> for window sizes.
```cpp
  varlen_type = 1;
}

int window_left = 0;
```
Should we define window_left/right as aotriton::v3::flash::WindowValue, since we already introduce this type at line 244?
No, these values are for generic SWA and any integer is valid input.
```cpp
void fused_attn_aotriton_fwd_qkvpacked(
    size_t b, size_t h, size_t max_seqlen, size_t d,
    bool is_training, float attn_scale, float dropout,
    uint64_t window_left, uint64_t window_right,
```
I think it should be int64_t, as the default (non-SWA) window sizes from NV upstream are -1.
```cpp
);
// Next we guard against an initial workspace-allocation which occurs in the
// JAX TE extension. We check for both pointers being null while retaining
// shape data, indicating the use of dummy data in the allocation pass.
```
Is it specific to JAX? How does Torch extension behave?
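For readers outside the thread, here is a minimal sketch of the guard the quoted comment describes; the TensorView struct is a stand-in, not TE's actual tensor type, and whether PyTorch's extension hits the same path is the open question above:

```cpp
#include <cstddef>
#include <vector>

// Stand-in for TE's tensor wrapper, for illustration only.
struct TensorView {
  void* dptr = nullptr;       // device pointer (null during the allocation pass)
  std::vector<size_t> shape;  // shape metadata is still populated
};

// The JAX extension first runs a workspace-allocation pass with dummy data:
// both pointers are null while the shapes are retained, so we can detect it.
bool is_workspace_allocation_pass(const TensorView& q, const TensorView& o) {
  return q.dptr == nullptr && o.dptr == nullptr && !q.shape.empty();
}
```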
```cpp
size_t workspace_size = 0;
bool pad_between_seqs = get_pad_between_seqs(
    input_cu_seqlens_q,
    input_cu_seqlens_kv,
```
I'm confused: how does seqlens_kv serve as the padded seqlens_q?
```cmake
)
set(__AOTRITON_IMAGE_SHA256_LIST
  "3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942
  "27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950
```
Why not add the other archs?
I wasn't sure what our support matrix for archs was.
```cpp
//TODO: release after TE integrates swa into AOTriton
bool is_no_mask_window_size= window_size_left == -1 && window_size_right == -1;
bool is_causal_mask_window_size = window_size_left ==-1 && window_size_right ==0;
```
Looks like we don't support the general SWA feature right now.
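To spell out the convention the quoted check relies on (this is just a restatement of the excerpt, not new TE code): (-1, -1) means no mask and (-1, 0) means a causal mask; any other pair would be a generic SWA window, which this path currently rejects until TE wires SWA through AOTriton:

```cpp
// Sketch of the quoted guard: only the two special window-size pairs are
// accepted until generic SWA is plumbed through to AOTriton.
bool is_supported_window_size(int window_size_left, int window_size_right) {
  bool is_no_mask = (window_size_left == -1 && window_size_right == -1);
  bool is_causal_mask = (window_size_left == -1 && window_size_right == 0);
  return is_no_mask || is_causal_mask;
}
```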
@xinyazhang could you take a second pass whenever you get the chance? Let me know if there's anything else you'd like addressed, or if you're satisfied with the changes. Thanks!
xinyazhang left a comment:
LGTM
However, @Micky774, I just rolled out 0.11.1b yesterday, which fixes the linker-script incompatibility issue and restores Navi31 support (unsure whether TE's users care about it). You probably want to use that instead; see ROCm/pytorch#2801.
@Micky774 Could you address Ilya's comments? Also, please merge the latest dev into this PR.
Description
Update build system for AOTriton 0.11b and upgrade FWD call to V3 API.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: