Conversation

@Micky774 Micky774 (Contributor) commented Oct 28, 2025

Description

Feature-update PR containing several iterative changes for client-driven optimization targets. It includes both API changes for CK/AITER and changes to the internal integration. See the list of changes below for specifics.

Note that this will not be ready to merge until ROCm/aiter#1212 is merged and this PR's AITER commit is updated.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Integrated support for native padding kernels in fwd/bwd
  • Added a BSHD + Padding --> THD + Padding conversion mechanism
  • Streamlined memory allocation logic
  • Added runtime max_seqlen calculation gated by the new env var NVTE_CK_RUNTIME_MAX_SEQLEN (see the sketch after this list)
  • Added v3_api_check support (temporary)
  • Implemented the new AITER/CK API
  • Updated MQA post-processing kernels
  • Removed pad_between_seqs (needs a follow-up PR cleaning up the test suite for old pad_between_seqs edge cases)
  • Added NVTE_CK_RUNTIME_NUM_SEGMENTS to guard runtime calculation of the number of segments in the JAX integration
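
For reference, a minimal sketch of the gating pattern, assuming host-visible cu_seqlens; other than the NVTE_CK_RUNTIME_MAX_SEQLEN env var itself, all names here are hypothetical and not this PR's actual code:

```cpp
// Minimal sketch of env-var-gated runtime max_seqlen derivation.
// Everything except NVTE_CK_RUNTIME_MAX_SEQLEN is hypothetical.
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <vector>

// Hypothetical helper: derive max_seqlen from host-visible cu_seqlens
// (cumulative sequence lengths) as the largest adjacent difference.
static int32_t runtime_max_seqlen(const std::vector<int32_t>& cu_seqlens) {
  int32_t max_len = 0;
  for (size_t i = 1; i < cu_seqlens.size(); ++i) {
    max_len = std::max(max_len, cu_seqlens[i] - cu_seqlens[i - 1]);
  }
  return max_len;
}

static bool env_enabled(const char* name) {
  const char* v = std::getenv(name);
  return v != nullptr && std::strcmp(v, "0") != 0;
}

// Gated: only pay for the runtime scan when the user opts in; otherwise
// fall back to the statically provided max_seqlen.
int32_t choose_max_seqlen(const std::vector<int32_t>& cu_seqlens,
                          int32_t static_max_seqlen) {
  if (env_enabled("NVTE_CK_RUNTIME_MAX_SEQLEN")) {
    return runtime_max_seqlen(cu_seqlens);
  }
  return static_max_seqlen;
}
```

Since this derivation reads host-visible sequence lengths at call time, it is incompatible with CUDA graph capture, which is the caveat flagged for the README later in this thread.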

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@wangye805 wangye805 (Collaborator) left a comment

Generally, I think we can try to remove all memsets except those for dq and dq_acc. We can confirm with the AITER/CK people.

```cpp
dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dk, 0, stream,
h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_qk,
```
Collaborator

Do we have gqa/mqa + MLA test cases with and without padding? If not, can we create those to verify this flow is actually working?

Contributor Author

I'll work on adding one on the JAX side -- for now I've added one on the TE side that isn't able to run because too few backends support it, but that may change, e.g., as we update AOTriton.

Collaborator

Thanks. Then let's skip the PyTorch-side gqa/mqa + MLA test for now. You can put a TODO here and add it later when other backends support it.

@wangye805 (Collaborator)

Let's also add how to use the runtime segment / max-seqlen env vars to the README under https://github.com/ROCm/TransformerEngine?tab=readme-ov-file#fused-attention-backends-on-rocm. Remind our customers that this will break CUDA graphs.

@Micky774 (Contributor Author)

@wangye805 I've now updated the readme, but let me know if you have specific thoughts on it.

@wangye805 wangye805 (Collaborator) left a comment

Please take a look at the several previously unresolved conversations.

- Updated debug message for BSHD-->THD conversion
- Added env variable to gate FWD output memset for padding
- Removed guards on memsets for d{Q,K,V} matrices
@wenchenvincent (Collaborator)

@Micky774 Could you rebase/merge the latest dev to incorporate the hot fixes for the sgpu tests?

@wangye805 (Collaborator)

pytorch test_numerics also shows some fused-attn related failures:
FAILED tests/pytorch/test_numerics.py::test_kv_cache_accuracy[False-FusedAttention-TransformerLayer-sbhd-False-126m-1-dtype1] - AssertionError: Outputs not close enough in tensor at idx=0. Maximum difference at location [0, 650] with -0.90625 vs 0.5654296875 (diff 1.4716796875).

Not sure whether this is related to our decision to remove the memsets.

@Micky774 (Contributor Author)

Those failures were due to a mix of not correctly dispatching to the is_SBHD workflow when handling the SBHD_2BSHD format and of miscalculating strides for that same format. Resolved now.
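
For context, a minimal sketch of why the two layouts need different stride sets (layout math only; all names here are hypothetical, not this PR's actual code):

```cpp
// Element-stride sketch for dense row-major [S, B, H, D] (SBHD) versus
// [B, S, H, D] (BSHD) tensors; the head dimension is contiguous in both.
#include <cstddef>

struct Strides { size_t seq, batch, head; };

// In SBHD, one sequence step jumps over a whole batch row (B*H*D elements);
// in BSHD, one sequence step only jumps over the heads (H*D elements).
Strides sbhd_strides(size_t b, size_t h, size_t d) {
  return {/*seq=*/b * h * d, /*batch=*/h * d, /*head=*/d};
}
Strides bshd_strides(size_t s, size_t h, size_t d) {
  return {/*seq=*/h * d, /*batch=*/s * h * d, /*head=*/d};
}
```

As I read the format name, SBHD_2BSHD keeps Q in SBHD while KV is in BSHD, so each tensor needs its own stride set; applying one set to both is exactly the kind of stride miscalculation described above.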

@wangye805 wangye805 (Collaborator) left a comment

For the newly added hybrid qkv formats in upstream (NVTE_SBHD_2BSHD, NVTE_BSHD_2SBHD, NVTE_THD_2BSHD, and NVTE_THD_2SBHD): in addition to the SBHD_2BSHD pytest failures, are we able to correctly handle the other three? Or are there only SBHD_2BSHD pytests for now?

NV upstream separates the format and is_ragged flags for q and kv and does the subsequent processing accordingly:

```cpp
NVTE_QKV_Format q_format = nvte_get_q_format(layout);
NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);
```

Maybe we can try a similar technique. If I recall correctly, we need padding/unpadding for just q in SBHD_2BSHD and for just k/v in BSHD_2SBHD.
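
A hypothetical sketch of what that per-tensor split could look like on our side; only the nvte_get_*_format helpers and the enum values are the upstream API quoted above, and the two pad/unpad flags simply transcribe the rule stated here (the real condition may generalize differently):

```cpp
// Hypothetical dispatch sketch: decide padding/unpadding per tensor.
#include <transformer_engine/fused_attn.h>

void plan_pad_unpad(NVTE_QKV_Layout layout) {
  NVTE_QKV_Format q_format = nvte_get_q_format(layout);
  NVTE_QKV_Format kv_format = nvte_get_kv_format(layout);
  bool is_ragged_q = (q_format == NVTE_QKV_Format::NVTE_THD);
  bool is_ragged_kv = (kv_format == NVTE_QKV_Format::NVTE_THD);

  // Transcribing the rule above: SBHD_2BSHD pads/unpads just Q,
  // BSHD_2SBHD just K/V.
  NVTE_QKV_Format qkv_format = nvte_get_qkv_format(layout);
  bool pad_unpad_q = (qkv_format == NVTE_QKV_Format::NVTE_SBHD_2BSHD);
  bool pad_unpad_kv = (qkv_format == NVTE_QKV_Format::NVTE_BSHD_2SBHD);

  // ...dispatch the Q-only or KV-only conversion kernels from these flags...
  (void)is_ragged_q; (void)is_ragged_kv;
  (void)pad_unpad_q; (void)pad_unpad_kv;
}
```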

Or it's okay if you want to leave this for another PR.

By the way, there is an "extra line" comment you may have ignored :-)

```cpp
max_tokens_q,
devPtrQWithoutPadding,
q_stride[1], (is_ragged? q_stride[2] : std::min(q_stride[0], q_stride[2])),
q_stride[1], q_stride[0],
```
Collaborator

Do we need std::min(q_stride[0], q_stride[2]) for the SBHD_2BSHD format with padding?

```cpp
max_tokens_q, max_tokens_kv,
devPtrQWithoutPadding,
q_stride[1], (is_ragged? q_stride[2] : std::min(q_stride[0], q_stride[2])),
q_stride[1], q_stride[0],
```
Collaborator

Similar question here for the SBHD_2BSHD "hybrid-style" format.


```cpp
NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrdQ, 0, q_storage_bytes, stream));
// for pad between seqs case, we need to reset all dq, dk, dv
if(pad_between_seqs){
if(is_padding && nvte_ck_zero_out_pad){
```
Collaborator

Based on Wen's analysis, dq, dk, and dv require zeroed-out padding locations for the subsequent grad computation, so they should be memset unconditionally (nvte_ck_zero_out_pad should only control the O tensor).
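
A hypothetical sketch of that suggested gating (pointer and size names are illustrative, not the PR's actual variables; error checking omitted):

```cpp
// Sketch: unconditional grad memsets, env-var-gated output memset.
#include <cuda_runtime.h>

void zero_attn_workspaces(void* devPtrdQ, void* devPtrdK, void* devPtrdV,
                          void* devPtrO, size_t q_bytes, size_t k_bytes,
                          size_t v_bytes, size_t o_bytes, bool is_padding,
                          bool nvte_ck_zero_out_pad, cudaStream_t stream) {
  // Per the analysis above, subsequent grad computation reads the padded
  // regions of dq/dk/dv, so they must start at zero unconditionally.
  cudaMemsetAsync(devPtrdQ, 0, q_bytes, stream);
  cudaMemsetAsync(devPtrdK, 0, k_bytes, stream);
  cudaMemsetAsync(devPtrdV, 0, v_bytes, stream);

  // The O tensor is the only one the env var may opt out of zeroing.
  if (is_padding && nvte_ck_zero_out_pad) {
    cudaMemsetAsync(devPtrO, 0, o_bytes, stream);
  }
}
```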
