@@ -2268,26 +2268,30 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  *                           stream, and persistent buffers for rope init/cache.
  * @param dst                The destination ggml_tensor whose computation
  *                           depends on the RoPE values (usually Qcur/Kcur).
- * @param sin_tensor_buffer  Pre-allocated buffer for storing repeated sin values.
- * @param cos_tensor_buffer  Pre-allocated buffer for storing repeated cos values.
  * @param theta_scale        Scalar exponent base for computing theta scale values.
  * @param freq_scale         Frequency scaling factor, applied to theta scale.
  * @param attn_factor        Attention scaling factor, applied to sin/cos.
  * @param is_neox            Whether to use Neox-style repeat strategy
  *                           (dim expansion vs repeat_interleave).
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
-                             void* sin_tensor_buffer, void* cos_tensor_buffer,
                              float* corr_dims, float ext_factor,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
-    // int sin/cos cache, cache has different repeat method depond on
-    // @param.is_neox
-
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
     ggml_tensor* src2 = dst->src[2];  // freq_factors
 
+    if (src2 == nullptr && ctx.rope_cache.cached
+        && ctx.rope_cache.ext_factor == ext_factor
+        && ctx.rope_cache.theta_scale == theta_scale
+        && ctx.rope_cache.freq_scale == freq_scale
+        && ctx.rope_cache.attn_factor == attn_factor
+        && ctx.rope_cache.is_neox == is_neox) {
+        // use cache.
+        return;
+    }
+
     int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
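This hunk reads and writes several fields on ctx.rope_cache but never shows the struct itself, which lives in the CANN backend context rather than in this file. As a rough sketch inferred purely from the accesses in this diff (not the actual upstream definition, which may hold more state), the cache member could look like this:

    #include <cstdint>

    // Inferred sketch only; field names match the accesses in this diff.
    struct rope_cache_sketch {
        bool    cached             = false;    // set once sin/cos have been filled
        float   ext_factor         = 0.0f;     // RoPE parameters the cache was built with
        float   theta_scale        = 0.0f;
        float   freq_scale         = 0.0f;
        float   attn_factor        = 0.0f;
        bool    is_neox            = false;
        int64_t theta_scale_length = 0;        // src0->ne[0] / 2 at cache-build time
        int64_t position_length    = 0;        // longest position count seen so far
        void*   theta_scale_cache  = nullptr;  // device buffer of theta scale powers
        void*   sin_cache          = nullptr;  // device buffer of repeated sin values
        void*   cos_cache          = nullptr;  // device buffer of repeated cos values
    };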
@@ -2316,8 +2320,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ctx.rope_cache.freq_scale != freq_scale) {
 
         ctx.rope_cache.theta_scale_length = theta_scale_length;
-        ctx.rope_cache.theta_scale = theta_scale;
-        ctx.rope_cache.freq_scale = freq_scale;
 
         if (ctx.rope_cache.theta_scale_cache != nullptr) {
             ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@@ -2342,7 +2344,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
             // return MIN(1, MAX(0, y)) - 1;
             yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
             void* yarn_ramp_buffer = yarn_ramp_allocator.get();
-            acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
+            acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float),
                                           theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
             float zero_value = 0, one_value = 1;
             float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
@@ -2411,6 +2413,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
     }
 
+    // init sin_repeat && cos_repeat, only to accelerate first layer on each device
+    if (position_length > ctx.rope_cache.position_length) {
+        ctx.rope_cache.position_length = position_length;
+        if (ctx.rope_cache.sin_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
+        }
+        if (ctx.rope_cache.cos_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
+        }
+        int64_t repeat_theta_length = theta_scale_length * position_length * 2;
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+    }
+
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
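The block added above grows the persistent sin/cos device buffers only when a longer run of positions shows up; shorter runs keep reusing the existing, larger allocation. Below is a minimal standalone sketch of that grow-only pattern, using the real ACL runtime calls the diff relies on (aclrtMalloc, aclrtFree, ACL_MEM_MALLOC_HUGE_FIRST). The helper name and the byte-size bookkeeping are illustrative; the diff itself keys growth on position_length instead:

    #include <acl/acl.h>   // aclrtMalloc / aclrtFree / ACL_MEM_MALLOC_HUGE_FIRST
    #include <cstddef>

    // Grow-only device buffer: free and reallocate only when the requested size
    // exceeds the current capacity, otherwise keep the existing allocation.
    static aclError grow_device_buffer(void** buf, size_t* capacity, size_t needed) {
        if (needed <= *capacity) {
            return ACL_SUCCESS;               // big enough already: reuse it
        }
        if (*buf != nullptr) {
            aclError err = aclrtFree(*buf);   // drop the smaller allocation first
            if (err != ACL_SUCCESS) {
                return err;
            }
            *buf = nullptr;
        }
        aclError err = aclrtMalloc(buf, needed, ACL_MEM_MALLOC_HUGE_FIRST);
        if (err == ACL_SUCCESS) {
            *capacity = needed;
        }
        return err;
    }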
@@ -2462,10 +2478,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     // repeat
@@ -2483,6 +2499,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                                 num_repeats, output_size);
     }
 
+    // Mark the cache valid; layers after the first reuse it instead of recomputing.
+    ctx.rope_cache.cached = true;
+    ctx.rope_cache.ext_factor = ext_factor;
+    ctx.rope_cache.theta_scale = theta_scale;
+    ctx.rope_cache.freq_scale = freq_scale;
+    ctx.rope_cache.attn_factor = attn_factor;
+    ctx.rope_cache.is_neox = is_neox;
+
     ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
         acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
         acl_cos_repeat_tensor);
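Recording the parameters only after the repeat step completes means a partially built cache is never marked valid, and the early return at the top of the function only fires when every parameter matches and no per-call freq_factors tensor (src2) is present. What makes reuse across layers safe is that the repeated sin/cos values depend only on the positions and these scalar parameters, never on a layer's activations. The following host-side reference covers the simplified path (no YaRN ramp, no freq_factors), following the parameter descriptions at the top of this diff; the function is illustrative and does not exist in the backend:

    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Reference (host-side, simplified): theta(pos, i) = freq_scale * pos * theta_scale^i,
    // and the cached tables store sin/cos of theta scaled by attn_factor.
    static void fill_rope_sin_cos(std::vector<float>& sin_vals, std::vector<float>& cos_vals,
                                  int64_t half_dim, int64_t n_pos,
                                  float theta_scale, float freq_scale, float attn_factor) {
        sin_vals.assign((size_t)(half_dim * n_pos), 0.0f);
        cos_vals.assign((size_t)(half_dim * n_pos), 0.0f);
        for (int64_t pos = 0; pos < n_pos; ++pos) {
            for (int64_t i = 0; i < half_dim; ++i) {
                float theta = freq_scale * (float) pos * std::pow(theta_scale, (float) i);
                sin_vals[(size_t)(pos * half_dim + i)] = std::sin(theta) * attn_factor;
                cos_vals[(size_t)(pos * half_dim + i)] = std::cos(theta) * attn_factor;
            }
        }
    }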
@@ -2504,10 +2528,7 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
 #endif
 
 void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    // TODO: use ascendc
-    // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
-    ggml_tensor* src1 = dst->src[1];
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2538,15 +2559,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
-    // sin/cos tensor length.
-    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
-    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
-    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
-    void* sin_tensor_buffer = sin_tensor_allocator.get();
-    void* cos_tensor_buffer = cos_tensor_allocator.get();
-
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
+    aclnn_cache_init(ctx, dst, corr_dims, ext_factor,
                      theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
@@ -2556,10 +2570,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
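One consequence of this change worth noting: the sin/cos buffers used to come from the per-op pool (sin_tensor_allocator / cos_tensor_allocator) and were released automatically when the op finished, whereas sin_cache and cos_cache are now persistent aclrtMalloc allocations owned by the context. The diff does not show where they are released; presumably the context tears them down when it is destroyed, along the lines of this hedged sketch (names illustrative, not upstream code):

    #include <acl/acl.h>

    // Hypothetical cleanup, not shown in this diff: a cache that owns persistent
    // device buffers has to free them when its owning context is destroyed.
    struct rope_cache_cleanup_sketch {
        void* theta_scale_cache = nullptr;
        void* sin_cache         = nullptr;
        void* cos_cache         = nullptr;

        ~rope_cache_cleanup_sketch() {
            if (theta_scale_cache) { aclrtFree(theta_scale_cache); }
            if (sin_cache)         { aclrtFree(sin_cache); }
            if (cos_cache)         { aclrtFree(cos_cache); }
        }
    };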