@@ -2268,26 +2268,30 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  * stream, and persistent buffers for rope init/cache.
  * @param dst The destination ggml_tensor whose computation
  *            depends on the RoPE values (usually Qcur/Kcur).
- * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
- * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
  * @param theta_scale Scalar exponent base for computing theta scale values.
  * @param freq_scale Frequency scaling factor, applied to theta scale.
  * @param attn_factor Attention scaling factor, applied to sin/cos.
  * @param is_neox Whether to use Neox-style repeat strategy
  *                (dim expansion vs repeat_interleave).
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
-                             void * sin_tensor_buffer, void * cos_tensor_buffer,
                              float * corr_dims, float ext_factor,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
-    // int sin/cos cache, cache has different repeat method depond on
-    // @param.is_neox
-
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
     ggml_tensor* src2 = dst->src[2];  // freq_factors

+    if (src2 == nullptr && ctx.rope_cache.cached
+        && ctx.rope_cache.ext_factor == ext_factor
+        && ctx.rope_cache.theta_scale == theta_scale
+        && ctx.rope_cache.freq_scale == freq_scale
+        && ctx.rope_cache.attn_factor == attn_factor
+        && ctx.rope_cache.is_neox == is_neox) {
+        // use cache.
+        return;
+    }
+
     int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
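The early-return path added above depends on a per-context rope_cache whose definition is not part of this diff. As a rough sketch of what it presumably contains (field names are taken from the code in this commit; exact types, defaults, and layout are assumptions):

struct ggml_cann_rope_cache {
    // Validity flag plus the RoPE parameters the cached sin/cos values were built with.
    bool    cached             = false;
    float   ext_factor         = 0.0f;
    float   theta_scale        = 0.0f;
    float   freq_scale         = 0.0f;
    float   attn_factor        = 0.0f;
    bool    is_neox            = false;
    // Persistent device buffers, grown on demand and reused across layers.
    int64_t theta_scale_length = 0;
    int64_t position_length    = 0;
    void *  theta_scale_cache  = nullptr;  // theta_scale^i table
    void *  sin_cache          = nullptr;  // repeated sin values
    void *  cos_cache          = nullptr;  // repeated cos values
};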
@@ -2316,8 +2320,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ctx.rope_cache.freq_scale != freq_scale) {

         ctx.rope_cache.theta_scale_length = theta_scale_length;
-        ctx.rope_cache.theta_scale = theta_scale;
-        ctx.rope_cache.freq_scale = freq_scale;

         if (ctx.rope_cache.theta_scale_cache != nullptr) {
             ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@@ -2342,7 +2344,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         // return MIN(1, MAX(0, y)) - 1;
         yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
         void * yarn_ramp_buffer = yarn_ramp_allocator.get();
-        acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
+        acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float),
                                                        theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
         float zero_value = 0, one_value = 1;
         float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
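For reference, the ramp value this block builds element-wise on the device corresponds roughly to the scalar function below (a sketch pieced together from the comment and denom_safe_value shown above; the helper name and the exact role of the index are illustrative, not part of the backend):

#include <cmath>

// Sketch of the per-element YaRN ramp: clamp((i - low) / max(0.001, high - low), 0, 1) - 1,
// giving a value in [-1, 0]; corr_dims[0] and corr_dims[1] play the roles of low and high.
static float yarn_ramp_ref(float low, float high, float i) {
    const float denom_safe = fmaxf(0.001f, high - low);  // matches denom_safe_value above
    const float y = (i - low) / denom_safe;
    return fminf(1.0f, fmaxf(0.0f, y)) - 1.0f;
}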
@@ -2411,6 +2413,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
     }

+    // Init sin_repeat && cos_repeat; only needed to accelerate the first layer on each device.
+    if (position_length > ctx.rope_cache.position_length) {
+        ctx.rope_cache.position_length = position_length;
+        if (ctx.rope_cache.sin_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
+        }
+        if (ctx.rope_cache.cos_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
+        }
+        int64_t repeat_theta_length = theta_scale_length * position_length * 2;
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+    }
+
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
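For a sense of scale (numbers purely illustrative): with a head dimension of 128, theta_scale_length is 64, and caching 4096 positions gives repeat_theta_length = 64 * 4096 * 2, so each of the two buffers is about 2 MiB; they are freed and reallocated only when a longer position sequence shows up.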
@@ -2462,10 +2478,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);

     // repeat
@@ -2483,6 +2499,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                                 num_repeats, output_size);
     }

+    // Record the parameters so layers after the first one can reuse this cache.
+    ctx.rope_cache.cached = true;
+    ctx.rope_cache.ext_factor = ext_factor;
+    ctx.rope_cache.theta_scale = theta_scale;
+    ctx.rope_cache.freq_scale = freq_scale;
+    ctx.rope_cache.attn_factor = attn_factor;
+    ctx.rope_cache.is_neox = is_neox;
+
     ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
         acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
         acl_cos_repeat_tensor);
@@ -2504,10 +2528,7 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
 #endif

 void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
-    // TODO: use ascendc
-    // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
-    ggml_tensor* src1 = dst->src[1];

     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2538,15 +2559,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {

     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;

-    // sin/cos tensor length.
-    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
-    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
-    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
-    void *sin_tensor_buffer = sin_tensor_allocator.get();
-    void *cos_tensor_buffer = cos_tensor_allocator.get();
-
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
+    aclnn_cache_init(ctx, dst, corr_dims, ext_factor,
                      theta_scale, freq_scale, attn_factor, is_neox);

     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
@@ -2556,10 +2570,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);

     aclTensor* acl_src = ggml_cann_create_tensor(src0);