Skip to content
Open
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
e90b991
[ROCm] manually pick up fwd native padding support from Meekail's PR
wangye805 Oct 16, 2025
9d02d52
Initial update
Micky774 Oct 16, 2025
81bac35
Updated stride
Micky774 Oct 16, 2025
54ee86a
Corrected typing in allocation portions
Micky774 Oct 16, 2025
47a7cab
Applied Ye's patch
Micky774 Oct 17, 2025
0e0064f
[ROCm] manually pick Meekail's PR to support native padding for bwd
wangye805 Oct 20, 2025
945ab5b
[ROCm] jax use runtime segment
wangye805 Oct 21, 2025
579b592
[ROCm] get runtime max_seqlen as well
wangye805 Oct 22, 2025
73247d9
[ROCm] support v2 bwd native padding
wangye805 Oct 22, 2025
7e1c3ef
Updated conversion to include bwd pass
Micky774 Oct 22, 2025
51090d3
Merge branch 'yewang12/te_aiter_native_padding_bwd' into zain/aiter-b…
Micky774 Oct 23, 2025
0e121ba
Added BWD BSHD-->THD conversion and minor logic refactor
Micky774 Oct 23, 2025
734692d
Corrected softmax lse bug
Micky774 Oct 23, 2025
5c24188
Updated logic flow and re-calculation
Micky774 Oct 23, 2025
b59d466
[ROCm] manually pick Meekail's PR to support native padding for bwd
wangye805 Oct 20, 2025
97073fe
Merge branch 'zain/aiter-bwd-bshd-thd' into zain/aiter-native-bshd-thd
Micky774 Oct 28, 2025
f27a99f
Added env var guard
Micky774 Oct 28, 2025
d757aef
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Oct 28, 2025
33c5912
Updated ptr variables and streamlined dispatch
Micky774 Oct 29, 2025
af57290
Added env guard
Micky774 Oct 29, 2025
bc8f4a7
Corrected bshd_to_thd conversion arguments
Micky774 Oct 29, 2025
b7f2cf8
Corrected logical flow
Micky774 Oct 30, 2025
3e48a02
Guarded memset and corrected allocation
Micky774 Nov 5, 2025
b1094c6
Remove V3 API check and guard memsets
Micky774 Nov 5, 2025
c3a0fce
PR comments
Micky774 Nov 6, 2025
9ab8df4
Updated documentation
Micky774 Nov 10, 2025
2adfb6e
PR review reconciliation
Micky774 Nov 10, 2025
bb3868d
Added explicit test
Micky774 Nov 12, 2025
52c8167
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Nov 12, 2025
6206d58
Formatting for bwd debug
Micky774 Nov 13, 2025
0582851
Resolved error when using mixed formats e.g. sbhd_2bshd
Micky774 Nov 14, 2025
78716de
Updated guard on flash-attention forced support
Micky774 Nov 14, 2025
85bb6f6
Added check for SBHD_2BSHD
Micky774 Nov 14, 2025
a12105d
Added guard on dk/dv memset
Micky774 Nov 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,10 @@ void fused_attn_ck_fwd_impl(
devPtrSoftmaxLSEWithoutPadding = workspace_next;
workspace_next = static_cast<void *>(static_cast<int8_t *>(workspace_next) + h*max_tokens_q*sizeof(float));
if(is_SBHD && is_padding){
//determine the o buffer based on workspace next section
devPtrOWithoutPadding = workspace_next;
workspace_next = static_cast<void *>(static_cast<int8_t *>(workspace_next) + o_storage_bytes);

//determine q, k ,v buffer based on the workspace next ptr and layout group
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(layout);
//Q ptr always comes at first
Expand Down Expand Up @@ -671,11 +675,10 @@ void fused_attn_ck_fwd_impl(
std::cout << "\nattn_fwd(ck): Converting BSHD to THD\n";
}
}
//determine the o buffer based on workspace next section
devPtrOWithoutPadding = workspace_next;
workspace_next = static_cast<void *>(static_cast<int8_t *>(workspace_next) + o_storage_bytes);
// reset the final results since padded places need to be 0
NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrO, 0, o_storage_bytes, stream));
if(is_batch && is_padding){
// reset the final results since padded places need to be 0
NVTE_CHECK_CUDA(cudaMemsetAsync(devPtrO, 0, o_storage_bytes, stream));
}

if (nvte_log_ck_config) {
std::cout<<std::endl<<"attn_fwd(ck): ";
Expand Down Expand Up @@ -964,7 +967,6 @@ void fused_attn_ck_bwd_impl(
// First h*max_tokens_q*sizeof(float) in workspace are for lse-d
void* lse_workspace = workspace;
workspace_next = static_cast<void *>(static_cast<int8_t *>(workspace_next) + h*max_tokens_q*sizeof(float));
NVTE_CHECK_CUDA(cudaMemsetAsync(lse_workspace, 0, h*max_tokens_q*sizeof(float), stream));

// The next section are for dq_acc_ptr
void* dq_acc_ptr = workspace_next;
Expand Down