 #include <aclnnop/aclnn_zero.h>
 #include <aclnnop/aclnn_index_copy.h>
 #include <aclnnop/aclnn_index_select.h>
+#include <aclnnop/aclnn_clamp.h>
+#include <aclnnop/aclnn_threshold.h>
 #include <float.h>
 
 #include <cmath>
@@ -2263,6 +2265,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                              void* sin_tensor_buffer, void* cos_tensor_buffer,
+                             float* corr_dims, float ext_factor,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
     // init sin/cos cache; the cache has a different repeat method depending on
@@ -2318,16 +2321,60 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
     float n_elements = theta_scale_length;
     aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
 
+    ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
+    aclTensor* acl_yarn_ramp_tensor = nullptr;
+    if (ext_factor != 0) {
+        // -rope_yarn_ramp
+        // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+        // return MIN(1, MAX(0, y)) - 1;
+        yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
+        void* yarn_ramp_buffer = yarn_ramp_allocator.get();
+        acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
+                                                       theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        float zero_value = 0, one_value = 1;
+        float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
+        aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
+        aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT);
+        aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT);
+        aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT);
+        aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT);
+
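+        // element-wise over the arange (each entry is i0/2), step by step:
+        //   Subs             t = i - low                      (alpha = 1)
+        //   InplaceDivs      t = t / max(0.001f, high - low)  -> y
+        //   InplaceThreshold t = max(0, t)   (entries <= 0 become 0)
+        //   InplaceClampMax  t = min(1, t)
+        //   InplaceSubs      t = t - 1                        -> -rope_yarn_ramp
+        //   InplaceMuls      t = t * ext_factor               -> -ramp_mix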
+        GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc);
+
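+        // acl_yarn_ramp_tensor now holds -ramp_mix = (MIN(1, MAX(0, y)) - 1) * ext_factor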
+        // theta_interp = freq_scale * theta_extrap;
+        // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+        // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
+        // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
+        //
+        // so we cache the multiplier (freq_scale - freq_scale * ramp_mix + ramp_mix); since the
+        // tensor computed above holds the negated ramp (-ramp_mix), this is obtained as
+        // freq_scale + (freq_scale - 1) * (-ramp_mix):
+        float freq_scale_1 = freq_scale - 1;
+        aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT);
+        aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one);
+
+        ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc);
+    }
+
     // power
     aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
     GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
                             acl_theta_scale_tensor);
 
-    // freq_scale
-    if (freq_scale != 1) {
+    if (ext_factor != 0) {
+        aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor);
+    } else if (freq_scale != 1) {
         aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
     }
-    ggml_cann_release_resources(ctx, acl_theta_scale);
+
+    ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale);
     } else {
         // use cache
         acl_theta_scale_tensor =
@@ -2385,6 +2432,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                                             GGML_MAX_DIMS, ACL_FORMAT_ND);
     aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
 
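+    // YaRN magnitude scaling: fold mscale = 1 + 0.1 * ln(1 / freq_scale) into
+    // attn_factor, as rope_yarn() does in the reference (CPU) implementation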
+    if (ext_factor != 0) {
+        attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
+
     // attn_factor
     if (attn_factor != 1) {
         aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
@@ -2465,8 +2516,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: n_dims <= ne0
     GGML_ASSERT(n_dims == ne0);
     GGML_ASSERT(n_dims % 2 == 0);
-    // TODO: ext_factor != 0
-    GGML_ASSERT(ext_factor == 0);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
@@ -2484,7 +2533,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     void *cos_tensor_buffer = cos_tensor_allocator.get();
 
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
+    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
                      theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
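As a scalar cross-check of what the new cache holds, here is a hypothetical host-side reference (not part of this patch; `yarn_theta_multiplier` is an illustrative name) that follows the rope_yarn_ramp formula quoted in the comments above:

```cpp
#include <algorithm>

// Per-element multiplier the patched cache stores, so that
// theta = theta_extrap * yarn_theta_multiplier(i0, ...).
static float yarn_theta_multiplier(int i0, float freq_scale, float ext_factor,
                                   const float corr_dims[2]) {
    const float y = (i0 / 2 - corr_dims[0]) /
                    std::max(0.001f, corr_dims[1] - corr_dims[0]);
    const float ramp     = 1.0f - std::min(1.0f, std::max(0.0f, y));  // rope_yarn_ramp
    const float ramp_mix = ramp * ext_factor;
    return freq_scale - freq_scale * ramp_mix + ramp_mix;
}
```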