AITER Native Padding Support and BSHD + Padding --> THD + Padding conversion #354
Base branch: dev · Changes from 23 commits
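This PR threads a second cumulative-sequence-length array through the CK fused-attention backward path. In the THD (token-packed) layout, `cu_seqlen_kv_ptr` holds cumulative *actual* lengths while `cu_seqlen_kv_padded_ptr` holds cumulative *padded* offsets; the BSHD + padding to THD + padding conversion in the title keeps each sequence at its padded offset, so the two arrays differ whenever a sequence is shorter than its slot. A minimal host-side sketch of the two arrays the kernels consume (the batch size and lengths here are made up for illustration, not taken from the PR):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int32_t max_seqlen = 128;                    // padded slot per sequence
  const std::vector<int32_t> lens = {100, 50, 128};  // actual lengths, b = 3

  std::vector<int32_t> cu_seqlen{0}, cu_seqlen_padded{0};
  for (int32_t len : lens) {
    cu_seqlen.push_back(cu_seqlen.back() + len);                     // {0,100,150,278}
    cu_seqlen_padded.push_back(cu_seqlen_padded.back() + max_seqlen);// {0,128,256,384}
  }
  // Token indices in [cu_seqlen_padded[i] + lens[i], cu_seqlen_padded[i+1])
  // are padding and must be skipped by the reduction kernels below.
  for (size_t i = 0; i < lens.size(); ++i)
    std::printf("seq %zu: data [%d, %d), padding [%d, %d)\n", i,
                cu_seqlen_padded[i], cu_seqlen_padded[i] + lens[i],
                cu_seqlen_padded[i] + lens[i], cu_seqlen_padded[i + 1]);
}
```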
```diff
@@ -15,6 +15,24 @@
 namespace ck_fused_attn{
 
+// TODO: unify with binary search in TE/common/fused_attn(rocm)/util
+// no device std::upper_bound
+// in an increasing array with given size len, search for the index that:
+// array[index] <= target < array[index+1]
+// guaranteed that target >=0 and target <= cu_seqlen[end-1]
+__forceinline__ __device__ int binary_search(int32_t target, const int32_t *array, uint64_t len) {
+  int left = 1, right = len - 1;
+  while (left < right) {
+    int mid = (left + right) / 2;
+    if (array[mid] <= target) {
+      left = mid + 1;
+    } else {
+      right = mid;
+    }
+  }
+  return left - 1;
+}
+
 // define dk_dv_reduce function only for fp16 and bf16 types
 template<typename DataType>
 __global__ void dk_dv_reduce(
```
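For a flattened token index, the kernels must recover which sequence owns it. The search starts at index 1 because the leading 0 entry of a cumulative array always satisfies the lower bound (the comment guarantees `target >= 0`). A host-side mirror of the device `binary_search` above, checking the invariant `array[i] <= target < array[i+1]` against a hypothetical padded cumulative array (values are illustrative only):

```cpp
#include <cassert>
#include <cstdint>

// Host-side mirror of the device binary_search (illustrative sketch):
// returns i such that array[i] <= target < array[i+1].
static int binary_search_host(int32_t target, const int32_t *array, uint64_t len) {
  int left = 1, right = static_cast<int>(len) - 1;
  while (left < right) {
    int mid = (left + right) / 2;
    if (array[mid] <= target) left = mid + 1;
    else right = mid;
  }
  return left - 1;
}

int main() {
  const int32_t cu_seqlen_padded[] = {0, 128, 256, 384};      // b = 3, len = b + 1
  assert(binary_search_host(0,   cu_seqlen_padded, 4) == 0);  // first token of seq 0
  assert(binary_search_host(127, cu_seqlen_padded, 4) == 0);  // last slot of seq 0
  assert(binary_search_host(128, cu_seqlen_padded, 4) == 1);  // first token of seq 1
  assert(binary_search_host(383, cu_seqlen_padded, 4) == 2);  // last slot of seq 2
}
```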
```diff
@@ -109,8 +127,9 @@ __global__ void dk_or_dv_reduce(
 // define dk_dv_reduce function in THD layout only for fp16 and bf16 types
 template<typename DataType>
 __global__ void dk_dv_reduce_thd(
-    uint64_t h, uint64_t hg, uint64_t d,
-    const int32_t* total_seqlen_kv_ptr,
+    uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
+    const int32_t* cu_seqlen_kv_ptr,
+    const int32_t* cu_seqlen_kv_padded_ptr,
     const DataType *dk_expanded,
     const DataType *dv_expanded,
     uint64_t stride_h_dkv_expanded, uint64_t stride_s_dkv_expanded,
```
```diff
@@ -124,11 +143,17 @@ __global__ void dk_dv_reduce_thd(
   uint64_t hdim_idx = threadIdx.x;
 
   assert(hdim_idx<d);
 
-  if(seqlen_idx >= *total_seqlen_kv_ptr){
+  if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
     return;
   }
 
+  if(cu_seqlen_kv_padded_ptr){
+    uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
+    uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
+    if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
+      return;
+    }
+  }
   // h guaranteed to be multiples of hg
   uint64_t head_idx_offset = h / hg;
```
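The added guard works in two steps: first it rejects thread indices beyond the total token count (`cu_seqlen_kv_padded_ptr[b]` when padding metadata is present, otherwise `cu_seqlen_kv_ptr[b]`), then it locates the owning sequence and returns early for tokens in that sequence's padding tail. A host-side sketch of the same predicate, reusing `binary_search_host` from the earlier snippet (illustrative, not part of the PR):

```cpp
// Returns true when a flattened THD token index points at a padding slot
// and the reduction kernel should skip it.
static bool is_padding_token(int32_t seqlen_idx,
                             const int32_t *cu_seqlen,         // actual lengths
                             const int32_t *cu_seqlen_padded,  // padded offsets
                             uint64_t b) {
  int seq_idx = binary_search_host(seqlen_idx, cu_seqlen_padded, b + 1);
  int32_t unpadded_size = cu_seqlen[seq_idx + 1] - cu_seqlen[seq_idx];
  return seqlen_idx >= cu_seqlen_padded[seq_idx] + unpadded_size;
}
// With the illustrative lengths {100, 50, 128}: token 99 is data, while
// tokens 100..127 are the padding tail of sequence 0 and are skipped.
```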
```diff
@@ -164,8 +189,9 @@ __global__ void dk_dv_reduce_thd(
 // When d_qk != d_v, we need to reduce dk and dv separately
 template<typename DataType>
 __global__ void dk_or_dv_reduce_thd(
-    uint64_t h, uint64_t hg, uint64_t d,
-    const int32_t* total_seqlen_kv_ptr,
+    uint64_t b, uint64_t h, uint64_t hg, uint64_t d,
+    const int32_t* cu_seqlen_kv_ptr,
+    const int32_t* cu_seqlen_kv_padded_ptr,
     const DataType *dk_or_dv_expanded,
     uint64_t stride_h_dk_or_dv_expanded, uint64_t stride_s_dk_or_dv_expanded,
     DataType *dk_or_dv,
```
```diff
@@ -178,10 +204,16 @@ __global__ void dk_or_dv_reduce_thd(
   assert(hdim_idx<d);
 
-  if(seqlen_idx >= *total_seqlen_kv_ptr){
+  if(seqlen_idx >= *((cu_seqlen_kv_padded_ptr? cu_seqlen_kv_padded_ptr: cu_seqlen_kv_ptr)+b)){
     return;
   }
 
+  if(cu_seqlen_kv_padded_ptr){
+    uint64_t seq_idx = binary_search(seqlen_idx, cu_seqlen_kv_padded_ptr, b+1);
+    uint64_t unpadded_size = cu_seqlen_kv_ptr[seq_idx+1] - cu_seqlen_kv_ptr[seq_idx];
+    if(seqlen_idx >= cu_seqlen_kv_padded_ptr[seq_idx] + unpadded_size){
+      return;
+    }
+  }
   // h guaranteed to be multiples of hg
   uint64_t head_idx_offset = h / hg;
```
```diff
@@ -312,6 +344,8 @@ void log_bwd_config(const char* func_name,
                     const bool uses_bwd_v3,
                     const bool is_v3_atomic_fp32,
                     const int how_v3_bf16_cvt,
+                    const void* cu_seqlen_q_padded_ptr,
+                    const void* cu_seqlen_kv_padded_ptr,
                     const fmha_bwd_args& fmha_args){
 
   bool ck_fused_attn_log_config = false;
```
```diff
@@ -337,6 +371,8 @@ void log_bwd_config(const char* func_name,
   std::cout<<"uses_bwd_v3: "<<uses_bwd_v3<<std::endl;
   std::cout<<"is_v3_atomic_fp32: "<<is_v3_atomic_fp32<<std::endl;
   std::cout<<"how_v3_bf16_cvt: "<<how_v3_bf16_cvt<<std::endl;
+  std::cout<<"cu_seqlen_q_padded_ptr: "<<cu_seqlen_q_padded_ptr<<std::endl;
+  std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
 
   // fmha_args debug
   std::cout<<"fmha_args: "<<std::endl;
```
```diff
@@ -353,9 +389,15 @@ void log_bwd_config(const char* func_name,
   std::cout<<"dk_ptr: "<<fmha_args.dk_ptr<<std::endl;
   std::cout<<"dv_ptr: "<<fmha_args.dv_ptr<<std::endl;
   std::cout<<"dbias_ptr: "<<fmha_args.dbias_ptr<<std::endl;
   std::cout<<"dq_acc_ptr: "<<fmha_args.dq_acc_ptr<<std::endl;
 
+  std::cout<<"seqstart_q_ptr: "<<fmha_args.seqstart_q_ptr<<std::endl;
+  std::cout<<"seqstart_k_ptr: "<<fmha_args.seqstart_k_ptr<<std::endl;
+  std::cout<<"seqlen_q_ptr: "<<fmha_args.seqlen_q_ptr<<std::endl;
+  std::cout<<"seqlen_k_ptr: "<<fmha_args.seqlen_k_ptr<<std::endl;
+  std::cout<<"cu_seqlen_q_ptr: "<<fmha_args.cu_seqlen_q_ptr<<std::endl;
+  std::cout<<"cu_seqlen_k_ptr: "<<fmha_args.cu_seqlen_k_ptr<<std::endl;
 
   std::cout<<"seqlen_q: "<<fmha_args.seqlen_q<<std::endl;
   std::cout<<"seqlen_k: "<<fmha_args.seqlen_k<<std::endl;
   std::cout<<"batch: "<<fmha_args.batch<<std::endl;
```
```diff
@@ -572,9 +614,12 @@ hipError_t ck_attn_bwd(
     is_mqa_gqa? dv_expanded_ptr:dv_ptr,
     has_dbias? (bias_shape==BiasShape::kBHSS ? dbias_ptr: dbias_expanded_ptr): nullptr,
     dq_acc_ptr, //dq_acc_buf
-    nullptr,//cu_seqlen_q
-    nullptr,//cu_seqlen_kv
+    nullptr,//seqstart_q_ptr
+    nullptr,//seqstart_k_ptr
+    nullptr, /* seqlen_q_ptr */
+    nullptr, /* seqlen_k_ptr */
+    nullptr, //cu_seqlen_q_ptr
+    nullptr, //cu_seqlen_k_ptr
     shape_seqlen_q,
     shape_seqlen_k,
     batch,
```
```diff
@@ -633,7 +678,7 @@ hipError_t ck_attn_bwd(
   }();
 
   // print ck traits and args when needed
-  log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);
+  log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_type, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, nullptr, nullptr, fmha_args);
   if (uses_bwd_v3)
   {
     set_aiter_asm_dir();
```
```diff
@@ -650,7 +695,10 @@ hipError_t ck_attn_bwd(
                                          deterministic,
                                          uses_bwd_v3,
                                          is_v3_atomic_fp32,
-                                         how_v3_bf16_cvt);
+                                         how_v3_bf16_cvt,
+                                         nullptr, //cu_seqlen_q_padded
+                                         nullptr, //cu_seqlen_kv_padded
+                                         false); //v3 api check
   if(average_runtime < 0){
     //TODO: better error out system
     throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass.");
```
```diff
@@ -784,6 +832,7 @@ hipError_t ck_attn_varlen_bwd(
     const void* v_ptr,
     uint64_t stride_h_v, uint64_t stride_s_v,
     const void* cu_seqlen_q_ptr, const void* cu_seqlen_kv_ptr,
+    const void* cu_seqlen_q_padded_ptr, const void* cu_seqlen_kv_padded_ptr,
     const void* o_ptr,
     uint64_t stride_h_o, uint64_t stride_s_o,
     const void* lse_thd_ptr,
```
```diff
@@ -809,6 +858,7 @@ hipError_t ck_attn_varlen_bwd(
     bool uses_bwd_v3,
     bool is_v3_atomic_fp32,
     int how_v3_bf16_cvt,
+    bool is_v3_api_check,
     hipStream_t stream){
 
   bool has_dropout = (dropout_probability > 0.f);
```
```diff
@@ -915,11 +965,14 @@ hipError_t ck_attn_varlen_bwd(
     dq_ptr,
     is_mqa_gqa? dk_expanded_ptr:dk_ptr,
     is_mqa_gqa? dv_expanded_ptr:dv_ptr,
-    nullptr,
+    nullptr, //dbias_ptr
     dq_acc_ptr, //dq_acc_buf
-    cu_seqlen_q_ptr,//cu_seqlen_q
-    cu_seqlen_kv_ptr,//cu_seqlen_kv
+    cu_seqlen_q_padded_ptr==nullptr? cu_seqlen_q_ptr: cu_seqlen_q_padded_ptr, //seqstart_q_ptr
+    cu_seqlen_kv_padded_ptr==nullptr? cu_seqlen_kv_ptr: cu_seqlen_kv_padded_ptr, //seqstart_k_ptr
+    nullptr, /* seqlen_q_ptr */
+    nullptr, /* seqlen_k_ptr */
+    cu_seqlen_q_ptr, //cu_seqlen_q_ptr
+    cu_seqlen_kv_ptr, //cu_seqlen_k_ptr
     max_seqlen_q, //seqlen_q, unused in group mode
     max_seqlen_k, //seqlen_kv, unused in group mode
     batch,
```
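The key change in the group-mode argument list: `seqstart_q_ptr`/`seqstart_k_ptr`, which locate each sequence inside the packed Q/K/V buffers, now receive the padded cumulative offsets when they exist, while `cu_seqlen_q_ptr`/`cu_seqlen_k_ptr` keep the true lengths so padding tails can be masked. Without padding metadata both roles collapse onto the same array. A sketch of that selection logic, with hypothetical names mirroring the diff (not a new API):

```cpp
// seqstart_* locates sequences in the possibly padded packed buffer;
// cu_seqlen_* always carries the true lengths for masking.
struct SeqPtrs { const void* seqstart; const void* cu_seqlen; };

static SeqPtrs pick_seq_ptrs(const void* cu_seqlen_ptr,
                             const void* cu_seqlen_padded_ptr) {
  return SeqPtrs{
      cu_seqlen_padded_ptr == nullptr ? cu_seqlen_ptr : cu_seqlen_padded_ptr,
      cu_seqlen_ptr};
}
```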
```diff
@@ -977,26 +1030,44 @@ hipError_t ck_attn_varlen_bwd(
     std::pair<const void*, const void*>{philox_seed_ptr, philox_offset_ptr}};
   }();
 
+  // modify the max_seqlen_q for better performance in 0-length cases
+  // lse_thd_ptr used as buffer
+  if(const char* env_p = std::getenv("NVTE_CK_RUNTIME_MAX_SEQLEN")) {
+    if(std::string(env_p) == "1" && !is_v3_api_check){
+      if(ck_fused_attn_log_config){
+        std::cout << "attn_bwd(ck): Enabling runtime max_seqlen calculation for small seqlen optimization.";
+      }
+      fmha_args.max_seqlen_q = get_runtime_max_seqlen(b, cu_seqlen_q_ptr, nullptr, lse_workspace_ptr, stream);
+      fmha_args.max_seqlen_k = get_runtime_max_seqlen(b, cu_seqlen_kv_ptr, nullptr, lse_workspace_ptr, stream);
+    }
+  }
+
   // print ck traits and args when needed
-  log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, fmha_args);
+  log_bwd_config(__FUNCTION__, data_type_str, is_group_mode, mask_type, bias_enum::no_bias, has_dbias, has_dropout, s_randval, deterministic, uses_bwd_v3, is_v3_atomic_fp32, how_v3_bf16_cvt, cu_seqlen_q_padded_ptr, cu_seqlen_kv_padded_ptr, fmha_args);
   if (uses_bwd_v3)
   {
     set_aiter_asm_dir();
   }
 
-  float average_runtime = aiter::mha_bwd(fmha_args,
-                                         stream_config,
-                                         data_type_str,
-                                         is_group_mode,
-                                         mask_type,
-                                         bias_enum::no_bias,
-                                         has_dbias,
-                                         s_randval,
-                                         deterministic,
-                                         uses_bwd_v3,
-                                         is_v3_atomic_fp32,
-                                         how_v3_bf16_cvt);
-  if(average_runtime < 0){
+  float average_runtime_or_v3_check_status = aiter::mha_bwd(fmha_args,
+                                                            stream_config,
+                                                            data_type_str,
+                                                            is_group_mode,
+                                                            mask_type,
+                                                            bias_enum::no_bias,
+                                                            has_dbias,
+                                                            s_randval,
+                                                            deterministic,
+                                                            uses_bwd_v3,
+                                                            is_v3_atomic_fp32,
+                                                            how_v3_bf16_cvt,
+                                                            uses_bwd_v3? cu_seqlen_q_padded_ptr: nullptr,
+                                                            uses_bwd_v3? cu_seqlen_kv_padded_ptr: nullptr,
+                                                            is_v3_api_check);
+  if(is_v3_api_check){
+    return (hipError_t)(average_runtime_or_v3_check_status > 0);
+  }
+  if(average_runtime_or_v3_check_status < 0){
     //TODO: better error out system
     throw std::runtime_error("fused attn configs not supported in ck_fused_attn bwd pass.");
   }
```
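Two behavioral additions land in this hunk. First, with `NVTE_CK_RUNTIME_MAX_SEQLEN=1`, `max_seqlen_q`/`max_seqlen_k` are recomputed at runtime from the device-side cumulative arrays (the LSE workspace doubles as scratch), so batches dominated by short or zero-length sequences are not tiled for the worst-case length. Second, when `is_v3_api_check` is set, `aiter::mha_bwd` is asked only whether the v3 path supports the given config, and its return value is reinterpreted as a status rather than an average runtime. A host-side sketch of what the runtime max-seqlen amounts to (illustrative; the real `get_runtime_max_seqlen` runs on the GPU on the given stream):

```cpp
#include <algorithm>
#include <cstdint>

// Host-side equivalent of the runtime max_seqlen computation:
// the largest per-sequence length recovered from a cumulative array.
static int32_t runtime_max_seqlen(uint64_t b, const int32_t* cu_seqlen) {
  int32_t max_len = 0;
  for (uint64_t i = 0; i < b; ++i)
    max_len = std::max(max_len, cu_seqlen[i + 1] - cu_seqlen[i]);
  return max_len;
}
```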
```diff
@@ -1006,6 +1077,8 @@ hipError_t ck_attn_varlen_bwd(
   dim3 block(d_qk);
   if (ck_fused_attn_log_config){
     std::cout<<std::endl<<"run dk_dv_reduce_thd: "<<std::endl;
+    std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
+    std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
     std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
     std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
     std::cout<<"stride_h_dkv_expanded: "<<stride_h_dk_expanded<<std::endl;
```
```diff
@@ -1018,8 +1091,9 @@ hipError_t ck_attn_varlen_bwd(
   CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
     hipLaunchKernelGGL(
       dk_dv_reduce_thd<CK_TILE_TYPE>, grid, block, 0, stream,
-      h, hg, d_qk,
-      static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
+      b, h, hg, d_qk,
+      static_cast<const int32_t*>(cu_seqlen_kv_ptr),
+      static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
       static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
       static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
       stride_h_dk_expanded, stride_s_dk_expanded,
```
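Note the launch-argument change: the old code handed the kernel a single pointer to the total unpadded token count (`cu_seqlen_kv_ptr + b`), whereas the kernels now receive both full arrays plus `b` so they can bound the grid by the padded total and mask per sequence. A small check of the two bounds, using the same illustrative values as the earlier snippets:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint64_t b = 3;
  const int32_t cu_seqlen_kv[]        = {0, 100, 150, 278};  // illustrative
  const int32_t cu_seqlen_kv_padded[] = {0, 128, 256, 384};
  assert(*(cu_seqlen_kv + b) == 278);     // old bound: total unpadded tokens
  assert(cu_seqlen_kv_padded[b] == 384);  // new bound when padding is present
}
```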
```diff
@@ -1030,6 +1104,8 @@ hipError_t ck_attn_varlen_bwd(
   dim3 block_dk(d_qk);
   if (ck_fused_attn_log_config){
     std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dk: "<<std::endl;
+    std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
+    std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
     std::cout<<"dk_expanded_ptr: "<<dk_expanded_ptr<<std::endl;
     std::cout<<"stride_h_dk_expanded: "<<stride_h_dk_expanded<<std::endl;
     std::cout<<"stride_s_dk_expanded: "<<stride_s_dk_expanded<<std::endl;
```
```diff
@@ -1040,8 +1116,9 @@ hipError_t ck_attn_varlen_bwd(
   CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
     hipLaunchKernelGGL(
       dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dk, 0, stream,
-      h, hg, d_qk,
-      static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
+      b, h, hg, d_qk,
+      static_cast<const int32_t*>(cu_seqlen_kv_ptr),
+      static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
       static_cast<CK_TILE_TYPE*>(dk_expanded_ptr),
       stride_h_dk_expanded, stride_s_dk_expanded,
       static_cast<CK_TILE_TYPE*>(dk_ptr),
```

Review thread on this hunk:

**Collaborator:** Do we have gqa/mqa + MLA testcases with and without padding? If not, can we create those to verify this flow is actually working?

**Contributor (author):** Will work on trying to add one on the JAX side -- for now I've added one on the TE side that isn't able to run because too few backends support it, but that may change, e.g. as we update AOTriton.

**Collaborator:** Thanks. Then let's skip the pytorch-side gqa/mqa + MLA test for now. You can put a to-do here and add it later when other backends support it.
```diff
@@ -1050,6 +1127,8 @@ hipError_t ck_attn_varlen_bwd(
   dim3 block_dv(d_v);
   if (ck_fused_attn_log_config){
     std::cout<<std::endl<<"run dk_or_dv_reduce_thd on dv: "<<std::endl;
+    std::cout<<"cu_seqlen_kv_ptr: "<<cu_seqlen_kv_ptr<<std::endl;
+    std::cout<<"cu_seqlen_kv_padded_ptr: "<<cu_seqlen_kv_padded_ptr<<std::endl;
     std::cout<<"dv_expanded_ptr: "<<dv_expanded_ptr<<std::endl;
     std::cout<<"stride_h_dv_expanded: "<<stride_h_dv_expanded<<std::endl;
     std::cout<<"stride_s_dv_expanded: "<<stride_s_dv_expanded<<std::endl;
```
```diff
@@ -1060,8 +1139,9 @@ hipError_t ck_attn_varlen_bwd(
   CK_FUSED_ATTN_TYPE_SWITCH_16BIT(dtype, CK_TILE_TYPE,
     hipLaunchKernelGGL(
       dk_or_dv_reduce_thd<CK_TILE_TYPE>, grid, block_dv, 0, stream,
-      h, hg, d_v,
-      static_cast<const int32_t*>(cu_seqlen_kv_ptr)+b,
+      b, h, hg, d_v,
+      static_cast<const int32_t*>(cu_seqlen_kv_ptr),
+      static_cast<const int32_t*>(cu_seqlen_kv_padded_ptr),
       static_cast<CK_TILE_TYPE*>(dv_expanded_ptr),
       stride_h_dv_expanded, stride_s_dv_expanded,
       static_cast<CK_TILE_TYPE*>(dv_ptr),
```