@@ -2965,7 +2965,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                              aclTensor* acl_cos_repeat_tensor,
                              aclTensor* acl_sin_repeat_tensor,
                              float theta_scale, float freq_scale,
-                             bool is_neox) {
+                             float attn_factor, bool is_neox) {
     // init sin/cos cache; the cache has a different repeat method depending on
     // @param.is_neox

@@ -3017,6 +3017,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
             ggml_type_size(src2->type), arange_ne, arange_nb, GGML_MAX_DIMS);
         aclnn_div_tensor(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor,
                          nullptr, true);
+        ACL_CHECK(aclDestroyTensor(acl_freq_factors_tensor));
     }

     // position
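For context: when src2 (the optional frequency factors) is present, the aclnn_div_tensor call above divides the cached theta-scale powers by it element-wise. A minimal scalar sketch of that step, in plain C++ with illustrative names (not the CANN kernel itself):

    #include <cstddef>

    // Scalar equivalent of the division applied to the theta-scale cache:
    // theta_scale_pow[i] /= freq_factors[i] for each rotary dimension pair.
    static void apply_freq_factors(float * theta_scale_pow,
                                   const float * freq_factors, size_t n) {
        if (freq_factors == nullptr) {
            return;  // src2 is optional; without it the powers are used as-is
        }
        for (size_t i = 0; i < n; i++) {
            theta_scale_pow[i] /= freq_factors[i];
        }
    }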
@@ -3047,16 +3048,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
               acl_theta_tensor);

-    // // power[] * position[] * freq_scale / freq_factors[]
-    // ggml_cann_pool_alloc theta_final_allocator(ctx.pool(),
-    //                                            theta_length *
-    //                                            sizeof(float_t));
-    // aclTensor* acl_theat_final_tensor = aclnn_zero(
-    //     ctx, theta_final_allocator.get(), sizeof(float_t) * theta_length,
-    //     theta_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t));
-    // aclnn_inplace_addcdiv(ctx, acl_theat_final_tensor, acl_theta_tensor,
-    //                       acl_freq_factors_tensor, freq_scale);
-
     // permute: [0,1,2,3]->[0,2,1,3]
     int64_t permute_ne[] = {arange_length, 1, position_length, 1};
     size_t permute_nb[GGML_MAX_DIMS];
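As far as the cache math goes, the aclnn_mul of the position tensor and the theta-scale powers in the hunk above amounts to a broadcasted outer product, one angle per (position, dimension pair); the permute that follows only rearranges the result into the layout the sin/cos step consumes. A small sketch of that product, with illustrative names (a reading of the broadcast, not the device code):

    #include <vector>

    // theta[p][i] = position[p] * theta_scale_pow[i]: a broadcasted outer product.
    std::vector<std::vector<float>> outer_theta(const std::vector<float> & position,
                                                const std::vector<float> & theta_scale_pow) {
        std::vector<std::vector<float>> theta(position.size(),
                                              std::vector<float>(theta_scale_pow.size()));
        for (size_t p = 0; p < position.size(); p++) {
            for (size_t i = 0; i < theta_scale_pow.size(); i++) {
                theta[p][i] = position[p] * theta_scale_pow[i];
            }
        }
        return theta;
    }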
@@ -3092,6 +3083,12 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         GGML_MAX_DIMS, ACL_FORMAT_ND);
     aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor);

+    // attn_factor
+    if (attn_factor != 1) {
+        aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
+        aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
+    }
+
     // repeat
     if (is_neox) {
         int64_t repeatsArray[] = {1, 1, 1, 2};
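The added attn_factor block scales the cached sin/cos values once rather than scaling the rotated output per element. Because the rotation output is linear in sin and cos (out = x*cos + rot(x)*sin), the two are equivalent. A tiny scalar check of that equivalence (illustrative only):

    #include <cassert>
    #include <cmath>

    int main() {
        const float x0 = 0.7f, x1 = -0.2f;   // one rotary pair
        const float theta = 0.3f, attn_factor = 1.25f;

        // (a) rotate first, then scale the output by attn_factor
        const float a0 = attn_factor * (x0 * std::cos(theta) - x1 * std::sin(theta));
        // (b) scale the cached sin/cos once, then rotate (what the hunk above does)
        const float c = attn_factor * std::cos(theta);
        const float s = attn_factor * std::sin(theta);
        const float b0 = x0 * c - x1 * s;

        assert(std::fabs(a0 - b0) < 1e-6f);  // identical up to rounding
        return 0;
    }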
@@ -3155,15 +3152,11 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
     memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));

-    // TODO: attn_factor != 1
-    GGML_ASSERT(attn_factor == 1);
     // TODO: n_dims <= ne0
     GGML_ASSERT(n_dims == ne0);
     GGML_ASSERT(n_dims % 2 == 0);
     // TODO: ext_factor != 0
     GGML_ASSERT(ext_factor == 0);
-    // TODO: type == GGML_TYPE_F16
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);

     const float theta_scale = powf(freq_base, -2.0f / n_dims);

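For reference, theta_scale = freq_base^(-2/n_dims) is the common ratio of the per-pair frequencies, so dimension pair i uses freq_base^(-2i/n_dims), and the cached angle for position p is roughly freq_scale * p * that power (further divided by freq_factors[i] when src2 is present). A minimal sketch of the progression, with illustrative names:

    #include <cmath>
    #include <vector>

    // Per-pair rotation angles for one position, mirroring the cache math:
    // angle_i(pos) = freq_scale * pos * freq_base^(-2*i / n_dims).
    std::vector<float> rope_angles(int n_dims, float freq_base, float freq_scale, float pos) {
        const float theta_scale = std::pow(freq_base, -2.0f / n_dims);
        std::vector<float> angles(n_dims / 2);
        float theta = 1.0f;                   // theta_scale^0
        for (int i = 0; i < n_dims / 2; i++) {
            angles[i] = freq_scale * pos * theta;
            theta *= theta_scale;             // advance to theta_scale^(i+1)
        }
        return angles;
    }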
@@ -3194,7 +3187,215 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
-                     theta_scale, freq_scale, is_neox);
+                     theta_scale, freq_scale, attn_factor, is_neox);
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src0);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+#ifdef ASCEND_310P
+    // Special ROPE operation for 310P
+
+    // roll input
+    void* input_roll_buffer;
+    aclTensor* acl_minus_one_tensor;
+    void* minus_one_scale_buffer = nullptr;
+    ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
+    ggml_cann_pool_alloc minus_one_scale_allocator(
+        ctx.pool(), sizeof(float_t) * src0->ne[0]);
+    if (!is_neox) {
+        // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
+        input_roll_buffer = roll_allocator.get();
+        int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2),
+                                    src0->ne[2], src0->ne[3]};
+        size_t input_roll_nb[GGML_MAX_DIMS];
+        input_roll_nb[0] = ggml_type_size(src0->type);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1];
+        }
+        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
+            input_roll_buffer, ggml_cann_type_mapping(src0->type),
+            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
+            GGML_MAX_DIMS);
+        aclTensor* acl_input_tensor = ggml_cann_create_tensor(
+            src0->data, ggml_cann_type_mapping(src0->type),
+            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
+            GGML_MAX_DIMS);
+
+        int64_t shifts[] = {1};
+        int64_t dims[] = {3};
+        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
+        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+
+        // init [-1, 1, -1, 1, ...]
+        minus_one_scale_buffer = minus_one_scale_allocator.get();
+
+        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
+        size_t minus_one_nb[GGML_MAX_DIMS];
+        minus_one_nb[0] = sizeof(float_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
+        }
+        acl_minus_one_tensor = aclnn_values(
+            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
+            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
+        int64_t dim = 3;
+        int64_t* index = new int64_t[src0->ne[0]];
+        for (int i = 0; i < src0->ne[0]; i++) {
+            index[i] = i / 2 * 2;
+        }
+        int64_t index_num = src0->ne[0];
+        float value = -1;
+        aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index,
+                                index_num, value);
+    } else {
+        // roll input: [q0,q1,q2,...] ->
+        // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]
+        input_roll_buffer = roll_allocator.get();
+        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
+            input_roll_buffer, ggml_cann_type_mapping(src0->type),
+            ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
+        aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0);
+
+        int64_t shifts[] = {src0->ne[0] / 2};
+        int64_t dims[] = {3};
+        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
+
+        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+        // init [-1, -1, -1, 1, 1, 1, ...]
+        minus_one_scale_buffer = minus_one_scale_allocator.get();
+        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
+        size_t minus_one_nb[GGML_MAX_DIMS];
+        minus_one_nb[0] = sizeof(float_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
+        }
+        acl_minus_one_tensor = aclnn_values(
+            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
+            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
+        // -1 * first half
+        int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
+        size_t first_half_nb[GGML_MAX_DIMS];
+        first_half_nb[0] = sizeof(float_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
+        }
+        aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
+            minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
+            first_half_nb, GGML_MAX_DIMS);
+        bool inplace = true;
+        float scale = -1;
+        aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
+        ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
+    }
+
+    // TODO: n_dims < ne0
+    GGML_ASSERT(n_dims == src0->ne[0]);
+
+    // input * scale
+    ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(),
+                                                  ggml_nbytes(src0));
+    void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get();
+    size_t input_nb[GGML_MAX_DIMS];
+    input_nb[0] = ggml_type_size(src0->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        input_nb[i] = input_nb[i - 1] * src0->ne[i - 1];
+    }
+    aclTensor* acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor(
+        input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type),
+        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
+    aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor(
+        input_roll_buffer, ggml_cann_type_mapping(src0->type),
+        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
+
+    aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor,
+              acl_input_roll_mul_scale_tensor);
+
+    // output
+    void* output_fp32_buffer;
+    if (src0->type == GGML_TYPE_F32) {
+        aclnn_inplace_mul(ctx, acl_src, acl_cos_reshape_tensor);
+        aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor,
+                          acl_sin_reshape_tensor);
+        aclnn_add(ctx, acl_src, acl_input_roll_mul_scale_tensor, acl_dst);
+        // TODO: ne0 != n_dims in mode2
+    } else if (src0->type == GGML_TYPE_F16) {
+        size_t input_fp32_nb[GGML_MAX_DIMS];
+        input_fp32_nb[0] = sizeof(float_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
+        }
+        ggml_cann_pool_alloc fp32_allocator1(
+            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+        void* input_fp32_buffer1 = fp32_allocator1.get();
+        aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
+            input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
+            input_fp32_nb, GGML_MAX_DIMS);
+        ggml_cann_pool_alloc fp32_allocator2(
+            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+        void* input_fp32_buffer2 = fp32_allocator2.get();
+        aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
+            input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
+            input_fp32_nb, GGML_MAX_DIMS);
+
+        ggml_cann_pool_alloc fp32_allocator(
+            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+        output_fp32_buffer = fp32_allocator.get();
+        aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
+            output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
+            input_fp32_nb, GGML_MAX_DIMS);
+        aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1);
+        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
+                  input_fp32_tensor2);
+        aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2,
+                  output_fp32_tensor);
+        aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
+
+        ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
+        ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
+        ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_src));
+    }
+    return;
+#endif
+
+    // src0 == GGML_TYPE_F16
+    // TODO: optimization this `if` code
+    if (src0->type == GGML_TYPE_F16) {
+        ggml_cann_pool_alloc sin_final_allocator(
+            ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type));
+        ggml_cann_pool_alloc cos_final_allocator(
+            ctx.pool(), src0->ne[0] * src0->ne[2] * ggml_type_size(src0->type));
+        void* sin_final_buffer = sin_final_allocator.get();
+        void* cos_final_buffer = cos_final_allocator.get();
+
+        int64_t sin_final_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
+        size_t sin_final_nb[GGML_MAX_DIMS];
+        sin_final_nb[0] = ggml_type_size(src0->type);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            sin_final_nb[i] = sin_final_nb[i - 1] * sin_final_ne[i - 1];
+        }
+        aclTensor* acl_sin_final_tensor = ggml_cann_create_tensor(
+            sin_final_buffer, ggml_cann_type_mapping(src0->type),
+            ggml_type_size(src0->type), sin_final_ne, sin_final_nb, GGML_MAX_DIMS);
+        aclTensor* acl_cos_final_tensor = ggml_cann_create_tensor(
+            cos_final_buffer, ggml_cann_type_mapping(src0->type),
+            ggml_type_size(src0->type), sin_final_ne, sin_final_nb, GGML_MAX_DIMS);
+
+        aclnn_cast(ctx, acl_sin_reshape_tensor, acl_sin_final_tensor,
+                   ggml_cann_type_mapping(src0->type));
+        aclnn_cast(ctx, acl_cos_reshape_tensor, acl_cos_final_tensor,
+                   ggml_cann_type_mapping(src0->type));
+        ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
+        acl_sin_reshape_tensor = acl_sin_final_tensor;
+        acl_cos_reshape_tensor = acl_cos_final_tensor;
+    }

     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
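In the ASCEND_310P branch above, the function returns before reaching the aclnnRotaryPositionEmbedding call, so RoPE is composed from generic ops instead: the input is rolled, multiplied by a ±1 mask, and combined as dst = src*cos + roll(src)*mask*sin. A scalar sketch of the non-neox (interleaved-pair) case, with illustrative names (a conceptual reading, not the device code):

    #include <cmath>
    #include <vector>

    // Interleaved (non-neox) RoPE on one row, matching the composition above:
    // roll swaps each (even, odd) pair, the mask is [-1, 1, -1, 1, ...], and
    // dst = src * cos + roll(src) * mask * sin.
    void rope_interleaved(const std::vector<float> & src, std::vector<float> & dst,
                          const std::vector<float> & angles /* one per pair */) {
        dst.resize(src.size());
        for (size_t i = 0; i + 1 < src.size(); i += 2) {
            const float c = std::cos(angles[i / 2]);
            const float s = std::sin(angles[i / 2]);
            dst[i]     = src[i] * c - src[i + 1] * s;  // roll -> src[i+1], mask = -1
            dst[i + 1] = src[i + 1] * c + src[i] * s;  // roll -> src[i],   mask = +1
        }
    }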
@@ -3206,10 +3407,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         acl_mode = 1;
     }

-    aclTensor* acl_x = ggml_cann_create_tensor(src0);
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     ACL_CHECK(aclnnRotaryPositionEmbeddingGetWorkspaceSize(
-        acl_x, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode,
+        acl_src, acl_cos_reshape_tensor, acl_sin_reshape_tensor, acl_mode,
         acl_dst, &workspaceSize, &executor));
     if (workspaceSize > 0) {
         ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
@@ -3219,7 +3418,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclnnRotaryPositionEmbedding(workspaceAddr, workspaceSize,
                                            executor, ctx.stream()));

-    ACL_CHECK(aclDestroyTensor(acl_x));
+    ACL_CHECK(aclDestroyTensor(acl_src));
     ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
     ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
     ACL_CHECK(aclDestroyTensor(acl_dst));