Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
e90b991
[ROCm] manually pick up fwd native padding support from Meekail's PR
wangye805 Oct 16, 2025
9d02d52
Initial update
Micky774 Oct 16, 2025
81bac35
Updated stride
Micky774 Oct 16, 2025
54ee86a
Corrected typing in allocation portions
Micky774 Oct 16, 2025
47a7cab
Applied Ye's patch
Micky774 Oct 17, 2025
0e0064f
[ROCm] manually pick Meekail's PR to support native padding for bwd
wangye805 Oct 20, 2025
945ab5b
[ROCm] jax use runtime segment
wangye805 Oct 21, 2025
579b592
[ROCm] get runtime max_seqlen as well
wangye805 Oct 22, 2025
73247d9
[ROCm] support v2 bwd native padding
wangye805 Oct 22, 2025
7e1c3ef
Updated conversion to include bwd pass
Micky774 Oct 22, 2025
51090d3
Merge branch 'yewang12/te_aiter_native_padding_bwd' into zain/aiter-b…
Micky774 Oct 23, 2025
0e121ba
Added BWD BSHD-->THD conversion and minor logic refactor
Micky774 Oct 23, 2025
734692d
Corrected softmax lse bug
Micky774 Oct 23, 2025
5c24188
Updated logic flow and re-calculation
Micky774 Oct 23, 2025
b59d466
[ROCm] manually pick Meekail's PR to support native padding for bwd
wangye805 Oct 20, 2025
97073fe
Merge branch 'zain/aiter-bwd-bshd-thd' into zain/aiter-native-bshd-thd
Micky774 Oct 28, 2025
f27a99f
Added env var guard
Micky774 Oct 28, 2025
d757aef
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Oct 28, 2025
33c5912
Updated ptr variables and streamlined dispatch
Micky774 Oct 29, 2025
af57290
Added env guard
Micky774 Oct 29, 2025
bc8f4a7
Corrected bshd_to_thd conversion arguments
Micky774 Oct 29, 2025
b7f2cf8
Corrected logical flow
Micky774 Oct 30, 2025
3e48a02
Guarded memset and corrected allocation
Micky774 Nov 5, 2025
b1094c6
Remove V3 API check and guard memsets
Micky774 Nov 5, 2025
c3a0fce
PR comments
Micky774 Nov 6, 2025
9ab8df4
Updated documentation
Micky774 Nov 10, 2025
2adfb6e
PR review reconciliation
Micky774 Nov 10, 2025
bb3868d
Added explicit test
Micky774 Nov 12, 2025
52c8167
Merge branch 'dev' into zain/aiter-native-bshd-thd
Micky774 Nov 12, 2025
6206d58
Formatting for bwd debug
Micky774 Nov 13, 2025
0582851
Resolved error when using mixed formats e.g. sbhd_2bshd
Micky774 Nov 14, 2025
78716de
Updated guard on flash-attention forced support
Micky774 Nov 14, 2025
85bb6f6
Added check for SBHD_2BSHD
Micky774 Nov 14, 2025
a12105d
Added guard on dk/dv memset
Micky774 Nov 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 60 files
+7 −2 .github/workflows/aiter-release.yaml
+1 −1 3rdparty/ck_helper/ck/config.h
+1 −1 3rdparty/composable_kernel
+702 −666 aiter/configs/a8w8_bpreshuffle_tuned_gemm.csv
+318 −281 aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv
+25 −11 aiter/configs/asm_a8w8_gemm.csv
+5 −13 aiter/jit/core.py
+1 −1 aiter/jit/optCompilerConfig.json
+57 −37 aiter/ops/gemm_op_a8w8.py
+18 −8 aiter/ops/mha.py
+48 −14 aiter/ops/triton/_triton_kernels/lean_atten.py
+86 −9 aiter/ops/triton/lean_atten.py
+9 −3 aiter/utility/base_tuner.py
+7 −6 csrc/ck_gemm_a8w8_bpreshuffle/README.md
+7 −4 csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.cu
+203 −55 csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py
+6 −2 csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py
+9 −12 csrc/include/asm_gemm_a8w8.h
+5 −5 csrc/include/mha_bwd.h
+61 −60 csrc/include/rocm_ops.hpp
+3 −5 csrc/include/torch/mha_v3_varlen_bwd.h
+5 −2 csrc/include/torch/mha_varlen_bwd.h
+114 −15 csrc/kernels/cache_kernels.cu
+6 −3 csrc/py_itfs_ck/mha_bwd_kernels.cu
+8 −9 csrc/py_itfs_ck/mha_fwd_kernels.cu
+45 −9 csrc/py_itfs_ck/mha_varlen_bwd_kernels.cu
+26 −35 csrc/py_itfs_ck/mha_varlen_fwd_kernels.cu
+10 −9 csrc/py_itfs_cu/asm_gemm_a16w16.cu
+231 −53 csrc/py_itfs_cu/asm_gemm_a8w8.cu
+7 −4 csrc/py_itfs_cu/asm_mha_bwd.cu
+7 −8 csrc/py_itfs_cu/asm_mha_fwd.cu
+43 −11 csrc/py_itfs_cu/asm_mha_varlen_bwd.cu
+7 −8 csrc/py_itfs_cu/asm_mha_varlen_fwd.cu
+0 −1 hsa/gfx942/bf16gemm/bf16gemm_outf32.csv
+ hsa/gfx942/bf16gemm/bf16gemm_outf32_tn_32x64_pf3.co
+ hsa/gfx942/bf16gemm/bf16gemm_outf32_tn_48x64_pf3.co
+ hsa/gfx942/bf16gemm/bf16gemm_outf32_tn_64x64_pf3.co
+ hsa/gfx942/bf16gemm/bf16gemm_outf32_tn_80x64_pf3.co
+ hsa/gfx942/bf16gemm/bf16gemm_outf32_tn_96x64_pf3.co
+17 −8 hsa/gfx942/fmha_v3_bwd/codegen.py
+ hsa/gfx942/i8gemm/I8gemm_bf16_perTokenI8_BpreShuffle_128x128.co
+ hsa/gfx942/i8gemm/I8gemm_bf16_perTokenI8_BpreShuffle_192x128.co
+64 −0 hsa/gfx942/i8gemm/codegen.py
+3 −0 hsa/gfx942/i8gemm/i8gemm_bf16_perTokenI8.csv
+0 −1 hsa/gfx950/bf16gemm/bf16gemm_outf32.csv
+ hsa/gfx950/bf16gemm/bf16gemm_outf32_tn_32x64_pf3.co
+ hsa/gfx950/bf16gemm/bf16gemm_outf32_tn_48x64_pf3.co
+ hsa/gfx950/bf16gemm/bf16gemm_outf32_tn_64x64_pf3.co
+ hsa/gfx950/bf16gemm/bf16gemm_outf32_tn_80x64_pf3.co
+ hsa/gfx950/bf16gemm/bf16gemm_outf32_tn_96x64_pf3.co
+58 −28 hsa/gfx950/fmha_v3_bwd/codegen.py
+76 −0 hsa/gfx950/i8gemm/codegen.py
+6 −3 op_tests/cpp/mha/benchmark_mha_bwd.cpp
+14 −7 op_tests/cpp/mha/benchmark_mha_fwd.cpp
+78 −63 op_tests/op_benchmarks/triton/bench_la.py
+36 −4 op_tests/test_concat_cache_mla.py
+150 −18 op_tests/test_mha_varlen.py
+56 −40 op_tests/triton_tests/test_la.py
+2 −1 requirements.txt
+0 −1 setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ hipError_t ck_attn_fwd(
uint64_t stride_b_o, uint64_t stride_h_o, uint64_t stride_s_o,
void* lse_ptr,
bool uses_fwd_v3,
int how_v3_bf16_cvt,
hipStream_t stream);

hipError_t ck_attn_varlen_fwd(
Expand All @@ -72,6 +73,7 @@ hipError_t ck_attn_varlen_fwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
bool is_training,
float scaling_factor,
float dropout_probability,
Expand All @@ -82,6 +84,8 @@ hipError_t ck_attn_varlen_fwd(
uint64_t stride_h_o, uint64_t stride_s_o,
void* lse_thd_ptr,
bool uses_fwd_v3,
int how_v3_bf16_cvt,
bool is_v3_api_check,
hipStream_t stream);

hipError_t ck_attn_bwd(
Expand Down Expand Up @@ -137,6 +141,7 @@ hipError_t ck_attn_varlen_bwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
const void* o_ptr,
uint64_t stride_h_o, uint64_t stride_s_o,
const void* lse_thd_ptr,
Expand All @@ -162,6 +167,7 @@ hipError_t ck_attn_varlen_bwd(
bool uses_bwd_v3,
bool is_v3_atomic_fp32,
int how_v3_bf16_cvt,
bool is_v3_api_check,
hipStream_t stream);

}//namespace ck_fused_attn
Expand Down
152 changes: 116 additions & 36 deletions transformer_engine/common/ck_fused_attn/src/ck_fused_attn_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@

namespace ck_fused_attn{

// TODO: unify with binary search in TE/common/fused_attn(rocm)/util
// no device std::upper_bound
// Given a non-decreasing array of length len, locate index such that:
//   array[index] <= target < array[index+1]
// Caller guarantees target >= 0 and target <= array[len-1].
__forceinline__ __device__ int binary_search(int32_t target, const int32_t *array, uint64_t len) {
  // Invariant: the answer's successor index lies within [lo, hi].
  int lo = 1, hi = len - 1;
  while (lo < hi) {
    int probe = lo + (hi - lo) / 2;  // overflow-safe midpoint
    if (array[probe] <= target) {
      lo = probe + 1;   // answer is at probe or to its right
    } else {
      hi = probe;       // answer is strictly left of probe
    }
  }
  return lo - 1;
}

// define dk_dv_reduce function only for fp16 and bf16 types
template<typename DataType>
__global__ void dk_dv_reduce(
Expand Down Expand Up @@ -109,8 +127,9 @@ __global__ void dk_or_dv_reduce(
// define dk_dv_reduce function in THD layout only for fp16 and bf16 types
template<typename DataType>
__global__ void dk_dv_reduce_thd(
uint64_t h, uint64_t hg, uint64_t d,
const int32_t* total_seqlen_kv_ptr,
uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
const int32_t* cu_seqlen_kv_ptr,
const int32_t* cu_seqlen_kv_padded_ptr,
const DataType *dk_expanded,
const DataType *dv_expanded,
uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded,
Expand All @@ -124,11 +143,17 @@ __global__ void dk_dv_reduce_thd(
uint64_t hdim_idx = threadIdx.x;

assert(hdim_idx<d);

if(seqlen_idx >= *total_seqlen_kv_ptr){
if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
return;
}

if(cu_seqlen_kv_padded_ptr){
uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
return;
}
}
// h guaranteed to be multiples of hg
uint64_t head_idx_offset = h / hg;

Expand Down Expand Up @@ -164,8 +189,9 @@ __global__ void dk_dv_reduce_thd(
// When d_qk != d_v, we need to reduce dk and dv separately
template<typename DataType>
__global__ void dk_or_dv_reduce_thd(
uint64_t h, uint64_t hg, uint64_t d,
const int32_t* total_seqlen_kv_ptr,
uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
const int32_t* cu_seqlen_kv_ptr,
const int32_t* cu_seqlen_kv_padded_ptr,
const DataType *dk_or_dv_expanded,
uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded,
DataType *dk_or_dv,
Expand All @@ -178,10 +204,16 @@ __global__ void dk_or_dv_reduce_thd(

assert(hdim_idx<d);

if(seqlen_idx >= *total_seqlen_kv_ptr){
if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
return;
}

if(cu_seqlen_kv_padded_ptr){
uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
return;
}
}
// h guaranteed to be multiples of hg
uint64_t head_idx_offset = h / hg;

Expand Down Expand Up @@ -312,6 +344,8 @@ void log_bwd_config(const char* func_name,
const bool uses_bwd_v3,
const bool is_v3_atomic_fp32,
const int how_v3_bf16_cvt,
const void* cu_seqlen_q_padded_ptr,
const void* cu_seqlen_kv_padded_ptr,
const fmha_bwd_args& fmha_args){

bool ck_fused_attn_log_config = false;
Expand All @@ -337,6 +371,8 @@ void log_bwd_config(const char* func_name,
std::cout<<"uses_bwd_v3: "<<uses_bwd_v3<<std::endl;
std::cout<<"is_v3_atomic_fp32: "<<is_v3_atomic_fp32<<std::endl;
std::cout<<"how_v3_bf16_cvt: "<<how_v3_bf16_cvt<<std::endl;
std::cout<<"cu_seqlen_q_padded_ptr: "<<cu_seqlen_q_padded_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;

// fmha_args debug
std::cout<<"fmha_args: "<<std::endl;
Expand All @@ -353,9 +389,15 @@ void log_bwd_config(const char* func_name,
std::cout<<"dk_ptr: "<<fmha_args.dk_ptr<<std::endl;
std::cout<<"dv_ptr: "<<fmha_args.dv_ptr<<std::endl;
std::cout<<"dbias_ptr: "<<fmha_args.dbias_ptr<<std::endl;
std::cout<<"dq_acc_ptr: "<<fmha_args.dq_acc_ptr<<std::endl;

std::cout<<"seqstart_q_ptr: "<<fmha_args.seqstart_q_ptr<<std::endl;
std::cout<<"seqstart_k_ptr: "<<fmha_args.seqstart_k_ptr<<std::endl;
std::cout<<"seqlen_q_ptr: "<<fmha_args.seqlen_q_ptr<<std::endl;
std::cout<<"seqlen_k_ptr: "<<fmha_args.seqlen_k_ptr<<std::endl;
std::cout<<"cu_seqlen_q_ptr: "<<fmha_args.cu_seqlen_q_ptr<<std::endl;
std::cout<<"cu_seqlen_k_ptr: "<<fmha_args.cu_seqlen_k_ptr<<std::endl;

std::cout<<"seqlen_q: "<<fmha_args.seqlen_q<<std::endl;
std::cout<<"seqlen_k: "<<fmha_args.seqlen_k<<std::endl;
std::cout<<"batch: "<<fmha_args.batch<<std::endl;
Expand Down Expand Up @@ -572,9 +614,12 @@ hipError_t ck_attn_bwd(
is_mqa_gqa? dv_expanded_ptr:dv_ptr,
has_dbias? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr): nullptr,
dq_acc_ptr, //dq_acc_buf
nullptr,//cu_seqlen_q
nullptr,//cu_seqlen_kv
nullptr,//seqstart_q_ptr
nullptr,//seqstart_k_ptr
nullptr, /* seqlen_q_ptr */
nullptr, /* seqlen_k_ptr */
nullptr, //cu_seqlen_q_ptr
nullptr, //cu_seqlen_k_ptr
shape_seqlen_q,
shape_seqlen_k,
batch,
Expand Down Expand Up @@ -633,7 +678,7 @@ hipError_t ck_attn_bwd(
}();

// print ck traits and args when needed
log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);
log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, nullptr, nullptr, fmha_args);
if (uses_bwd_v3)
{
set_aiter_asm_dir();
Expand All @@ -650,7 +695,10 @@ hipError_t ck_attn_bwd(
deterministic,
uses_bwd_v3,
is_v3_atomic_fp32,
how_v3_bf16_cvt);
how_v3_bf16_cvt,
nullptr, //cu_seqlen_q_padded
nullptr, //cu_seqlen_kv_padded
false); //v3 api check
if(average_runtime < 0){
//TODO: better error out system
throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass.");
Expand Down Expand Up @@ -784,6 +832,7 @@ hipError_t ck_attn_varlen_bwd(
const void* v_ptr,
uint64_t stride_h_v, uint64_t stride_s_v,
const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
const void* o_ptr,
uint64_t stride_h_o, uint64_t stride_s_o,
const void* lse_thd_ptr,
Expand All @@ -809,6 +858,7 @@ hipError_t ck_attn_varlen_bwd(
bool uses_bwd_v3,
bool is_v3_atomic_fp32,
int how_v3_bf16_cvt,
bool is_v3_api_check,
hipStream_t stream){

bool has_dropout = (dropout_probability > 0.f);
Expand Down Expand Up @@ -915,11 +965,14 @@ hipError_t ck_attn_varlen_bwd(
dq_ptr,
is_mqa_gqa? dk_expanded_ptr:dk_ptr,
is_mqa_gqa? dv_expanded_ptr:dv_ptr,
nullptr,
nullptr, //dbias_ptr
dq_acc_ptr, //dq_acc_buf
cu_seqlen_q_ptr,//cu_seqlen_q
cu_seqlen_kv_ptr,//cu_seqlen_kv
cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr
cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr
nullptr, /* seqlen_q_ptr */
nullptr, /* seqlen_k_ptr */
cu_seqlen_q_ptr, //cu_seqlen_q_ptr
cu_seqlen_kv_ptr, //cu_seqlen_k_ptr
max_seqlen_q, //seqlen_q, unused in group mode
max_seqlen_k, //seqlen_kv, unused in group mode
batch,
Expand Down Expand Up @@ -977,26 +1030,44 @@ hipError_t ck_attn_varlen_bwd(
std::pair<const void*, const void*>{philox_seed_ptr, philox_offset_ptr}};
}();

// modify the max_seqlen_q for better performance in 0-length cases
// lse_thd_ptr used as buffer
if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) {
if(std::string(env_p) == "1" && !is_v3_api_check){
if(ck_fused_attn_log_config){
std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.";
}
fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream);
fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream);
}
}

// print ck traits and args when needed
log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);
log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, fmha_args);
if (uses_bwd_v3)
{
set_aiter_asm_dir();
}

float average_runtime = aiter::mha_bwd(fmha_args,
stream_config,
data_type_str,
is_group_mode,
mask_type,
bias_enum::no_bias,
has_dbias,
s_randval,
deterministic,
uses_bwd_v3,
is_v3_atomic_fp32,
how_v3_bf16_cvt);
if(average_runtime < 0){
float average_runtime_or_v3_check_status = aiter::mha_bwd(fmha_args,
stream_config,
data_type_str,
is_group_mode,
mask_type,
bias_enum::no_bias,
has_dbias,
s_randval,
deterministic,
uses_bwd_v3,
is_v3_atomic_fp32,
how_v3_bf16_cvt,
uses_bwd_v3? cu_seqlen_q_padded_ptr: nullptr,
uses_bwd_v3? cu_seqlen_kv_padded_ptr: nullptr,
is_v3_api_check);
if(is_v3_api_check){
return (hipError_t)(average_runtime_or_v3_check_status > 0);
}
if(average_runtime_or_v3_check_status < 0){
//TODO: better error out system
throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass.");
}
Expand All @@ -1006,6 +1077,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block(d_qk);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_dv_reduce_thd: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
std::cout<<"stride_h_dkv_expanded: "<<stride_h_dk_expanded<<std::endl;
Expand All @@ -1018,8 +1091,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_dv_reduce_thd<CK_TILE_TYPE>, grid, block, 0, stream,
h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
stride_h_dk_expanded, stride_s_dk_expanded,
Expand All @@ -1030,6 +1104,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block_dk(d_qk);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dk: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
std::cout<<"stride_h_dk_expanded: "<<stride_h_dk_expanded<<std::endl;
std::cout<<"stride_s_dk_expanded: "<<stride_s_dk_expanded<<std::endl;
Expand All @@ -1040,8 +1116,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dk, 0, stream,
h, hg, d_qk,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_qk,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have gqa/mqa + MLA testcases w and w/ padding? If not, can we create those to verify this flow is actually working

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will work on trying to add one in the JAX side -- for now I've added one on the TE side that isn't able to run due to too few backends supporting it, but that may change e.g. as we update AOTriton

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Then let's skip the pytorch side gqa/mqa + MLA test for now. You can put a to-do here and add it later when other backends support it

static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
stride_h_dk_expanded, stride_s_dk_expanded,
static_cast<CK_TILE_TYPE*>(dk_ptr),
Expand All @@ -1050,6 +1127,8 @@ hipError_t ck_attn_varlen_bwd(
dim3 block_dv(d_v);
if (ck_fused_attn_log_config){
std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dv: "<<std::endl;
std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
std::cout<<"stride_h_dv_expanded: "<<stride_h_dv_expanded<<std::endl;
std::cout<<"stride_s_dv_expanded: "<<stride_s_dv_expanded<<std::endl;
Expand All @@ -1060,8 +1139,9 @@ hipError_t ck_attn_varlen_bwd(
CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
hipLaunchKernelGGL(
dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dv, 0, stream,
h, hg, d_v,
static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
b, h, hg, d_v,
static_cast<const int32_t*>(cu_seqlen_kv_ptr),
static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
stride_h_dv_expanded, stride_s_dv_expanded,
static_cast<CK_TILE_TYPE*>(dv_ptr),
Expand Down
Loading