Commit f5e24a5

cann: update the alibi with max_bias
1 parent 8a902b9 commit f5e24a5
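
This commit scales the pre-computed position bias (bcast_pse_tensor) by per-head ALiBi slopes derived from max_bias, built on-device with arange, fill_scalar, and pow. For reference, here is a minimal host-side C++ sketch (not part of the commit) of the slope schedule the new device code reproduces: head h gets m0^(h+1) for the first n_heads_log2_floor heads, and m1^(2*(h - n_heads_log2_floor) + 1) for the rest.

// Minimal host-side sketch of the ALiBi slope schedule; mirrors the
// device-side arange + fill_scalar + pow sequence in the hunk below.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<float> alibi_slopes(int64_t n_head, float max_bias) {
    const int n_heads_log2_floor = 1u << (uint32_t)floor(log2((double)n_head));
    const float m0 = powf(2.0f, -max_bias          / n_heads_log2_floor);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);

    std::vector<float> slopes(n_head);
    for (int64_t h = 0; h < n_head; h++) {
        slopes[h] = h < n_heads_log2_floor
            ? powf(m0, (float)(h + 1))                              // exponents 1, 2, 3, ...
            : powf(m1, (float)(2 * (h - n_heads_log2_floor) + 1));  // exponents 1, 3, 5, ...
    }
    return slopes;
}

int main() {
    for (float s : alibi_slopes(12, 8.0f)) {
        printf("%g\n", s);
    }
    return 0;
}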

File tree

1 file changed: +103 −5 lines


ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 103 additions & 5 deletions
@@ -2690,10 +2690,6 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     memcpy(&maxBias, (float*)dst->op_params + 1, sizeof(float));
     memcpy(&logitSoftcap, (float*)dst->op_params + 2, sizeof(float));
 
-    // if(logitSoftcap != 0.0f){
-    //     // call the non-fa implementation
-    // }else{
-
     size_t faElemSize = sizeof(uint16_t);
     auto faDataType = ACL_FLOAT16; //ACL_BF16;
 
@@ -2825,6 +2821,108 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
 #endif
             ggml_cann_release_resources(ctx, acl_mask_f16_trunc_tensor);
         }
+
+        if(maxBias != 0.0f){
+            // alibi
+            const int64_t ne2_ne3 = src0->ne[2] * src0->ne[3];
+            const int64_t n_head = src0->ne[2];
+            const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+            float m0 = powf(2.0f, -(maxBias) / n_heads_log2_floor);
+            float m1 = powf(2.0f, -(maxBias / 2.0f) / n_heads_log2_floor);
+            // init arange
+            ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+                                                  ne2_ne3 * faElemSize);
+            void* tmp_arange_buffer = arange_allocator.get();
+
+            // arange1: [1, ..., n_heads_log2_floor+1)
+            float start = 1;
+            float stop = n_heads_log2_floor + 1;
+            float step = 1;
+            int64_t n_elements_arange = n_heads_log2_floor;
+
+            int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+            size_t tmp_arange1_nb[] = {faElemSize};
+            aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+                tmp_arange_buffer, faDataType, faElemSize,
+                tmp_arange1_ne, tmp_arange1_nb,
+                GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+            aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
+
+            aclTensor* tmp_arange2_tensor = nullptr;
+            if (n_heads_log2_floor < ne2_ne3) {
+                // arange2: [1, ..., 2 * (ne2_ne3 - n_heads_log2_floor) + 1)
+                start = 1;
+                stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+                step = 2;
+                n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+                int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                size_t tmp_arange2_nb[] = {faElemSize};
+
+                tmp_arange2_tensor = ggml_cann_create_tensor(
+                    (char*)tmp_arange_buffer +
+                        n_heads_log2_floor * faElemSize,
+                    faDataType, faElemSize,
+                    tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+                             n_elements_arange);
+            }
+
+            // init mk_base
+            ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+                                                   ne2_ne3 * faElemSize);
+            void* tmp_mk_base_buffer = mk_base_allocator.get();
+            int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+            size_t tmp_mk_base1_nb[] = {faElemSize};
+            aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+                tmp_mk_base_buffer, faDataType, faElemSize,
+                tmp_mk_base1_ne, tmp_mk_base1_nb,
+                GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+            aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+            aclTensor* tmp_mk_base2_tensor = nullptr;
+            if (n_heads_log2_floor < ne2_ne3) {
+                int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+                size_t tmp_mk_base2_nb[] = {faElemSize};
+                tmp_mk_base2_tensor = ggml_cann_create_tensor(
+                    (char*)tmp_mk_base_buffer +
+                        n_heads_log2_floor * faElemSize,
+                    faDataType, faElemSize,
+                    tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+                aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+            }
+
+            // init mk
+            int64_t tmp_mk_base_ne[] = {ne2_ne3};
+            size_t tmp_mk_base_nb[] = {faElemSize};
+            aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+                tmp_mk_base_buffer, faDataType, faElemSize,
+                tmp_mk_base_ne, tmp_mk_base_nb,
+                GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+            aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+                tmp_arange_buffer, faDataType, faElemSize,
+                tmp_mk_base_ne, tmp_mk_base_nb,
+                GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+            aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+            // reshape mk
+            int64_t tmp_mk_ne[] = {1, 1, src0->ne[2], src0->ne[3]};
+            size_t tmp_mk_nb[GGML_MAX_DIMS];
+            tmp_mk_nb[0] = faElemSize;
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+            }
+            aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+                tmp_mk_base_buffer, faDataType, faElemSize,
+                tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+                ACL_FORMAT_ND);
+            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, bcast_pse_tensor, tmp_mk_tensor);
+
+            ggml_cann_release_resources(ctx, tmp_arange1_tensor, tmp_arange2_tensor,
+                tmp_mk_base1_tensor, tmp_mk_base2_tensor, tmp_mk_base_tensor,
+                tmp_arange_tensor, tmp_mk_tensor);
+        }
     }
 
 #ifdef DEBUG
@@ -2931,4 +3029,4 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     ggml_cann_release_resources(ctx, acl_src0_f16_tensor, acl_src1_f16_tensor, acl_src2_f16_tensor, acl_dst_f16_tensor, acl_dst_tensor);
     if(src3)
         ggml_cann_release_resources(ctx, bcast_pse_tensor);
-}
+}
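
The final reshape-and-InplaceMul step views the ne2_ne3 slope buffer as shape {1, 1, ne2, ne3}, so dims 0 and 1 broadcast and each slope scales an entire (seq, kv) plane of the bias. A hypothetical host-side model of that broadcast, assuming bcast_pse_tensor is stored contiguously as [ne3][ne2][n_seq][n_kv]:

#include <cstdint>
#include <vector>

// Hypothetical model of the InplaceMul broadcast: every (head, batch)
// plane of the position bias is multiplied by a single slope value.
static void scale_pse(std::vector<float>& pse,           // [ne3][ne2][n_seq][n_kv], contiguous
                      const std::vector<float>& slopes,  // ne2 * ne3 values, head-major per batch
                      int64_t n_kv, int64_t n_seq, int64_t ne2, int64_t ne3) {
    for (int64_t i3 = 0; i3 < ne3; i3++) {
        for (int64_t i2 = 0; i2 < ne2; i2++) {
            const float m = slopes[i3 * ne2 + i2];       // one slope per plane
            float* plane = pse.data() + (i3 * ne2 + i2) * n_seq * n_kv;
            for (int64_t i = 0; i < n_seq * n_kv; i++) {
                plane[i] *= m;
            }
        }
    }
}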
