@@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_gamma = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_one_cache,
-        ctx.f32_one_cache_element,
+        &ctx.rms_norm_one_tensor_cache.cache,
+        ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
         1,  // dims
@@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_rstd = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_zero_cache,
-        ctx.f32_zero_cache_element,
+        &ctx.rms_norm_zero_tensor_cache.cache,
+        ctx.rms_norm_zero_tensor_cache.size,
         src->ne,
         acl_rstd_nb,
         GGML_MAX_DIMS,
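
The rename above suggests the backend context now pairs each constant f32 buffer (the all-ones gamma and the all-zeros rstd) with its element count in a small cache struct instead of two loose members. A minimal host-side sketch of that lazy-grow pattern, assuming hypothetical names (f32_tensor_cache, get_or_grow) and std::vector in place of device memory:

    #include <cstdint>
    #include <vector>

    // Hypothetical mirror of a constant-tensor cache: backing storage plus the
    // number of elements currently initialized (names are illustrative only).
    struct f32_tensor_cache {
        std::vector<float> cache;
        int64_t size = 0;
    };

    // Return a buffer holding at least n copies of `value`, reusing the cache
    // and refilling it only when a larger request arrives.
    static const float* get_or_grow(f32_tensor_cache& c, int64_t n, float value) {
        if (c.size < n) {
            c.cache.assign((size_t) n, value);
            c.size = n;
        }
        return c.cache.data();
    }

    int main() {
        f32_tensor_cache ones, zeros;
        const float* gamma = get_or_grow(ones, 4096, 1.0f);   // all-ones gamma
        const float* rstd  = get_or_grow(zeros, 4096, 0.0f);  // all-zeros rstd
        return (gamma[0] == 1.0f && rstd[0] == 0.0f) ? 0 : 1;
    }
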
@@ -2248,43 +2248,31 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
  * 6. Expand sin/cos values by repeat or repeat_interleave depending
  *    on whether @param is_neox is enabled.
- * 7. Store the computed values into persistent buffers
- *    (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
- *
- * @param ctx         The CANN backend context, holding memory pool,
- *                    stream, and persistent buffers for rope init/cache.
- * @param dst         The destination ggml_tensor whose computation
- *                    depends on the cached RoPE values (usually Qcur/Kcur).
- * @param theta_scale Scalar exponent base for computing theta scale values.
- * @param freq_scale  Frequency scaling factor, applied to theta scale.
- * @param attn_factor Attention scaling factor, applied to sin/cos.
- * @param is_neox     Whether to use Neox-style repeat strategy
- *                    (dim expansion vs repeat_interleave).
+ *
+ * @param ctx               The CANN backend context, holding memory pool,
+ *                          stream, and persistent buffers for rope init/cache.
+ * @param dst               The destination ggml_tensor whose computation
+ *                          depends on the RoPE values (usually Qcur/Kcur).
+ * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
+ * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
+ * @param theta_scale       Scalar exponent base for computing theta scale values.
+ * @param freq_scale        Frequency scaling factor, applied to theta scale.
+ * @param attn_factor       Attention scaling factor, applied to sin/cos.
+ * @param is_neox           Whether to use Neox-style repeat strategy
+ *                          (dim expansion vs repeat_interleave).
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+                             void* sin_tensor_buffer, void* cos_tensor_buffer,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
     // int sin/cos cache, cache has different repeat method depond on
     // @param.is_neox
-    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
-    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
-    // used for accuracy testing
-    bool is_attention = is_q || is_k;
-
-    // just compute in first layer in attention
-    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-    if (is_attention && !is_fisrt_layer) {
-        return;
-    }
 
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
     ggml_tensor* src2 = dst->src[2];  // freq_factors
 
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    int64_t theta_scale_length = ne00 / 2;
+    int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
                                theta_scale_length * sizeof(float_t)};
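
The numbered steps in the doc comment above reduce to simple scalar math. Below is a plain-C++ sketch of that math for reference; the function name, the [n_pos][ne0] output layout, and the host-side loops are illustrative assumptions, not the backend's aclnn-based implementation:

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Scalar sketch of the sin/cos cache built above, for head dimension ne0 and
    // n_pos positions. Output layout [n_pos][ne0]; the half-sized angle vector is
    // tiled for Neox mode and interleaved otherwise.
    static void rope_sin_cos_cache(int64_t ne0, int64_t n_pos,
                                   const int32_t* positions,   // step 3: position ids
                                   const float* freq_factors,  // optional, may be null
                                   float theta_scale, float freq_scale,
                                   float attn_factor, bool is_neox,
                                   std::vector<float>& sin_out,
                                   std::vector<float>& cos_out) {
        const int64_t half = ne0 / 2;
        sin_out.assign((size_t)(n_pos * ne0), 0.0f);
        cos_out.assign((size_t)(n_pos * ne0), 0.0f);
        for (int64_t p = 0; p < n_pos; ++p) {
            for (int64_t i = 0; i < half; ++i) {
                // steps 1-2: theta_scale^i scaled by freq_scale (and freq_factors if present)
                float scale = std::pow(theta_scale, (float) i) * freq_scale;
                if (freq_factors != nullptr) {
                    scale /= freq_factors[i];
                }
                // steps 3-5: theta = position * scale, then sin/cos scaled by attn_factor
                const float theta = (float) positions[p] * scale;
                const float s = std::sin(theta) * attn_factor;
                const float c = std::cos(theta) * attn_factor;
                // step 6: repeat (Neox) vs repeat_interleave (normal mode)
                const int64_t j0 = is_neox ? i : 2 * i;
                const int64_t j1 = is_neox ? i + half : 2 * i + 1;
                sin_out[(size_t)(p * ne0 + j0)] = s;
                sin_out[(size_t)(p * ne0 + j1)] = s;
                cos_out[(size_t)(p * ne0 + j0)] = c;
                cos_out[(size_t)(p * ne0 + j1)] = c;
            }
        }
    }
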
@@ -2302,21 +2290,32 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
     }
 
-    // init theta scale, just one time
-    if (ctx.rope_init_ptr == nullptr || !is_attention) {
-        // theta_scale arange, [0,1,...,ne00/2 - 1]
-        if (ctx.rope_init_ptr != nullptr){
-            ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
+    // theta_scale arange, [0,1,...,ne00/2 - 1]
+    aclTensor* acl_theta_scale_tensor = nullptr;
+    // cache theta scale
+    if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
+        // theta_scale and freq_scale should not change during the current token inference process,
+        // so we can directly use == here instead of comparing the absolute difference.
+        ctx.rope_cache.theta_scale != theta_scale ||
+        ctx.rope_cache.freq_scale != freq_scale) {
+
+        ctx.rope_cache.theta_scale_length = theta_scale_length;
+        ctx.rope_cache.theta_scale = theta_scale;
+        ctx.rope_cache.freq_scale = freq_scale;
+
+        if (ctx.rope_cache.theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
 
-        aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+
         float start = 0;
         float step = 1;
-        float stop = ne00 / 2;
-        float n_elements = ne00 / 2;
+        float stop = theta_scale_length;
+        float n_elements = theta_scale_length;
         aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
 
         // power
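
The cache above is keyed on (theta_scale_length, theta_scale, freq_scale); the arange/pow result is rebuilt only when one of them changes, and exact float equality is deliberate since the parameters stay identical across layers within one token's inference. A tiny sketch of that invalidation check with illustrative names:

    #include <cstdint>

    // Illustrative mirror of the rebuild condition used above.
    struct rope_cache_key {
        int64_t theta_scale_length = 0;
        float   theta_scale        = 0.0f;
        float   freq_scale         = 0.0f;
    };

    // Exact == on floats is safe here: within one token's inference the values
    // never change, so an epsilon comparison would only add cost.
    static bool needs_rebuild(const rope_cache_key& cached, int64_t len,
                              float theta_scale, float freq_scale) {
        return cached.theta_scale_length != len ||
               cached.theta_scale != theta_scale ||
               cached.freq_scale  != freq_scale;
    }
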
@@ -2328,35 +2327,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         if (freq_scale != 1) {
             aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
         }
-
-        // freq_factors
-        if (src2) {
-            aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
-                src2->data, ggml_cann_type_mapping(src2->type),
-                ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-            aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
-            ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
-        }
-        // release
-        ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
+        ggml_cann_release_resources(ctx, acl_theta_scale);
+    } else {
+        // use cache
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
+                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
     }
 
-    // init sin_repeat && cos_repeat, one token just init in 0 layer
-    if (position_length > ctx.max_prompt_length) {
-        ctx.max_prompt_length = position_length;
-        int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
-        if (ctx.rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
-            ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
-        }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+    ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
+    // freq_factors
+    if (src2) {
+        freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float_t));
+        void* freq_fac_res_ptr = freq_fac_res_allocator.get();
+        aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
+            src2->data, ggml_cann_type_mapping(src2->type),
+            ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
+            freq_fac_res_ptr, ACL_FLOAT, sizeof(float_t),
+            theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
+        std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
+        ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
     }
 
-    aclTensor* acl_theta_scale_tensor =
-        ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
-                                theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
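
Because acl_theta_scale_tensor can now alias the persistent theta-scale cache, the freq_factors division above writes into a scratch buffer from the pool and then swaps the local tensor handles, instead of dividing in place as the old code did. A standard-C++ sketch of the same idea (names are illustrative):

    #include <cstddef>
    #include <vector>

    // Divide cached theta-scale values by freq_factors without touching the
    // cache: write the quotient into scratch and point `current` at it, the
    // same role std::swap on the tensor handles plays above.
    static const float* apply_freq_factors(const float* theta_scale_cache,
                                           const float* freq_factors,  // may be null
                                           std::vector<float>& scratch, size_t n) {
        const float* current = theta_scale_cache;
        if (freq_factors != nullptr) {
            scratch.resize(n);
            for (size_t i = 0; i < n; ++i) {
                scratch[i] = theta_scale_cache[i] / freq_factors[i];
            }
            current = scratch.data();  // cache memory itself is left untouched
        }
        return current;
    }
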
@@ -2397,17 +2391,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
     }
 
-    int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+    int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float_t);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     // repeat
@@ -2449,6 +2443,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: use ascendc
     // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
+    ggml_tensor* src1 = dst->src[1];
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2481,8 +2476,16 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
+    // sin/cos tensor length.
+    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
+    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    void* sin_tensor_buffer = sin_tensor_allocator.get();
+    void* cos_tensor_buffer = cos_tensor_allocator.get();
+
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
+    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
+                     theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
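
The two pool allocations above replace the old persistent ctx.rope_sin_ptr/ctx.rope_cos_ptr buffers; each holds one float per (position, dimension) pair, i.e. src0->ne[0] * src1->ne[0] elements per call. A worked example with assumed sizes (head dimension 128, 512 positions in the batch):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Assumed shapes: ne0 = head dimension, n_pos = positions in the batch.
        int64_t ne0 = 128, n_pos = 512;
        int64_t repeat_theta_length = ne0 * n_pos;                      // 65536 elements
        int64_t bytes = repeat_theta_length * (int64_t) sizeof(float);  // 262144 bytes (256 KiB)
        std::printf("%lld elements, %lld bytes per sin/cos buffer\n",
                    (long long) repeat_theta_length, (long long) bytes);
        return 0;
    }
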
@@ -2491,10 +2494,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);