
Conversation

@Micky774 (Contributor) commented Nov 3, 2025

Description

Update build system for AOTriton 0.11b and upgrade FWD call to V3 API

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

string(REPLACE ";" "," ARCH_LIST_COMMA_STR "${CMAKE_HIP_ARCHITECTURES}")
set(__AOTRITON_VER "0.11b")
set(__AOTRITON_SHA256
"a2a974e0ad929a5e5827c0f896c59bda4872459cbaf8dd8e0a00407f404491cf" # rocm7.0
Contributor:

This can be removed.
TE should never download the runtime; it must build the runtime from source with a custom suffix, due to a potential conflict with the libaotriton shipped by PyTorch.

bool is_training, float scaling_factor, float dropout_probability,
NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
int window_size_left, int window_size_right, NVTE_QKV_Layout layout,
@xinyazhang (Contributor), Nov 3, 2025:

Consider using std::optional.
All integer values are valid inputs for AOTriton's SWA (hence I sometimes refer to it as "generic SWA").
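
A minimal sketch of that suggestion, assuming std::nullopt is used to mean "no bound on this side" so every plain integer stays available as a genuine window value; the signature mirrors the diff, and the exact mapping onto AOTriton is an assumption:

#include <cstddef>
#include <optional>

// Hypothetical signature: an absent optional means "unbounded on that
// side", so no integer sentinel like -1 is needed.
void fused_attn_aotriton_fwd_qkvpacked(
    size_t b, size_t h, size_t max_seqlen, size_t d,
    bool is_training, float attn_scale, float dropout,
    std::optional<int> window_left,    // std::nullopt => no left bound
    std::optional<int> window_right);  // std::nullopt => no right bound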

void fused_attn_aotriton_fwd_qkvpacked(
size_t b, size_t h, size_t max_seqlen, size_t d,
bool is_training, float attn_scale, float dropout,
uint64_t window_left, uint64_t window_right,
Contributor:
uint64_t?

Collaborator:

I think it should be int64_t, as the default (non-SWA) window sizes from NV upstream are -1.

Contributor (Author):

I've set it to int32_t since that's ultimately what AOTriton uses.
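
A quick illustration of why the signed choice matters here (a standalone sketch, not code from this PR):

#include <cstdint>
#include <cstdio>

int main() {
  // NV upstream uses -1 as the "unbounded window" default; forced into
  // an unsigned type it wraps to UINT64_MAX instead of staying -1.
  uint64_t wrapped = static_cast<uint64_t>(-1);
  // A signed type preserves the sentinel; int32_t matches what AOTriton
  // ultimately consumes.
  int32_t window_left = -1;
  std::printf("%llu vs %d\n", static_cast<unsigned long long>(wrapped), window_left);
  return 0;
}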

void fused_attn_aotriton_fwd(
size_t b, size_t h_q, size_t h_kv, size_t max_seqlen_q, size_t max_seqlen_kv, size_t d,
bool is_training, float attn_scale, float dropout,
uint64_t window_left, uint64_t window_right,
Contributor:
uint64_t?

void fused_attn_aotriton_fwd_qkvpacked(
size_t b, size_t h, size_t max_seqlen, size_t d,
bool is_training, float attn_scale, float dropout,
uint64_t window_left, uint64_t window_right,
Contributor:
As I mentioned above, use std::optional<int> for the window sizes.

varlen_type = 1;
}

int window_left = 0;
Collaborator:

Should we define window_left/right as aotriton::v3::flash::WindowValue, since we already introduced this type on line 244?

Contributor:

No, these values are for generic SWA, and any integer is a valid input.
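
For context, a sketch of the generic-SWA rule being described, where any integer pair defines a visibility band around the query position (an illustration only, not AOTriton's actual implementation):

// For query position q, key position k is visible iff it falls in the
// band [q - window_left, q + window_right]; negative values just move
// the band's edges, so every integer pair carries meaning.
bool in_window(int q, int k, int window_left, int window_right) {
  return (k >= q - window_left) && (k <= q + window_right);
}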

void fused_attn_aotriton_fwd_qkvpacked(
size_t b, size_t h, size_t max_seqlen, size_t d,
bool is_training, float attn_scale, float dropout,
uint64_t window_left, uint64_t window_right,
Collaborator:
I think it should be int64_t, as the default (non-SWA) window sizes from NV upstream are -1.

);
// Next we guard against an initial workspace-allocation which occurs in the
// JAX TE extension. We check for both pointers being null while retaining
// shape data, indicating the use of dummy data in the allocation pass.
Collaborator:
Is this specific to JAX? How does the Torch extension behave?
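
For reference, a sketch of the guard the quoted comment describes; the pointer and parameter names here are hypothetical stand-ins for whatever the extension actually passes:

#include <cstddef>

// Hypothetical helper illustrating the allocation-pass guard: null data
// pointers with real shape metadata signal a sizing-only call.
void fwd_or_size_only(const void* devPtrQKV, const void* devPtrO,
                      size_t required_bytes, size_t* workspace_size) {
  if (devPtrQKV == nullptr && devPtrO == nullptr) {
    *workspace_size = required_bytes;  // report size, skip kernel launch
    return;
  }
  // ...real execution path would launch the attention kernel here...
}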

size_t workspace_size = 0;
bool pad_between_seqs = get_pad_between_seqs(
input_cu_seqlens_q,
input_cu_seqlens_kv,
Collaborator:
I'm confused: how does seqlens_kv serve as the padded seqlens_q?

)
set(__AOTRITON_IMAGE_SHA256_LIST
"3a06a99971dddb7703a30378f1c5d6b41468d926ea51821156d1b6857b985bc4" # amd-gfx942
"27fc21f6761d57987a700436de8cf29cbdd9eeee91318dfed596eeb147d219ad" # amd-gfx950
Collaborator:
Why not add the other archs?

Contributor (Author):
I wasn't sure what our support matrix for archs was.


// TODO: release after TE integrates SWA into AOTriton
bool is_no_mask_window_size = window_size_left == -1 && window_size_right == -1;
bool is_causal_mask_window_size = window_size_left == -1 && window_size_right == 0;
Collaborator:
Looks like we don't support the general SWA feature right now.
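
A compact restatement of what the quoted guard accepts, as a sketch (the helper name is hypothetical):

// Only the two NV-style sentinel pairs pass; any other (left, right)
// pair would be generic SWA, which this path does not yet handle.
bool is_supported_window(int window_size_left, int window_size_right) {
  bool no_mask = window_size_left == -1 && window_size_right == -1;  // full attention
  bool causal = window_size_left == -1 && window_size_right == 0;    // causal mask
  return no_mask || causal;
}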

@Micky774 (Contributor, Author) commented:
@xinyazhang could you take a second pass whenever you get the chance? Let me know if there's anything else you'd like addressed, or if you're satisfied with the changes. Thanks!

@xinyazhang (Contributor) left a comment:

LGTM.
However, @Micky774, I just rolled out 0.11.1b yesterday, which fixes the linker-script incompatibility issue and restores Navi31 support (unsure whether that matters to TE's users). You probably want to use it instead; see ROCm/pytorch#2801.

@wenchenvincent (Collaborator) commented:
@Micky774 Could you address Ilya's comments? Also, please merge the latest dev branch into this PR.
