@@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_gamma = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_one_cache,
-        ctx.f32_one_cache_element,
+        &ctx.rms_norm_one_tensor_cache.cache,
+        ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
         1,  // dims
@@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_rstd = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_zero_cache,
-        ctx.f32_zero_cache_element,
+        &ctx.rms_norm_zero_tensor_cache.cache,
+        ctx.rms_norm_zero_tensor_cache.size,
         src->ne,
         acl_rstd_nb,
         GGML_MAX_DIMS,
@@ -2248,43 +2248,31 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
  * 6. Expand sin/cos values by repeat or repeat_interleave depending
  *    on whether @param is_neox is enabled.
- * 7. Store the computed values into persistent buffers
- *    (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
- *
- * @param ctx         The CANN backend context, holding memory pool,
- *                    stream, and persistent buffers for rope init/cache.
- * @param dst         The destination ggml_tensor whose computation
- *                    depends on the cached RoPE values (usually Qcur/Kcur).
- * @param theta_scale Scalar exponent base for computing theta scale values.
- * @param freq_scale  Frequency scaling factor, applied to theta scale.
- * @param attn_factor Attention scaling factor, applied to sin/cos.
- * @param is_neox     Whether to use Neox-style repeat strategy
- *                    (dim expansion vs repeat_interleave).
+ *
+ * @param ctx               The CANN backend context, holding memory pool,
+ *                          stream, and persistent buffers for rope init/cache.
+ * @param dst               The destination ggml_tensor whose computation
+ *                          depends on the RoPE values (usually Qcur/Kcur).
+ * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
+ * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
+ * @param theta_scale       Scalar exponent base for computing theta scale values.
+ * @param freq_scale        Frequency scaling factor, applied to theta scale.
+ * @param attn_factor       Attention scaling factor, applied to sin/cos.
+ * @param is_neox           Whether to use Neox-style repeat strategy
+ *                          (dim expansion vs repeat_interleave).
  */
 static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+                             void* sin_tensor_buffer, void* cos_tensor_buffer,
                              float theta_scale, float freq_scale,
                              float attn_factor, bool is_neox) {
     // int sin/cos cache, cache has different repeat method depond on
     // @param.is_neox
-    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
-    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
-    // used for accuracy testing
-    bool is_attention = is_q || is_k;
-
-    // just compute in first layer in attention
-    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-    if (is_attention && !is_fisrt_layer) {
-        return;
-    }
 
     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
     ggml_tensor* src2 = dst->src[2];  // freq_factors
 
-    GGML_TENSOR_BINARY_OP_LOCALS
-
-    int64_t theta_scale_length = ne00 / 2;
+    int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
                                theta_scale_length * sizeof(float_t)};
@@ -2302,21 +2290,32 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
     }
 
-    // init theta scale, just one time
-    if (ctx.rope_init_ptr == nullptr || !is_attention) {
-        // theta_scale arange, [0,1,...,ne00/2 - 1]
-        if (ctx.rope_init_ptr != nullptr){
-            ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
+    // theta_scale arange, [0,1,...,ne00/2 - 1]
+    aclTensor* acl_theta_scale_tensor = nullptr;
+    // cache theta scale
+    if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
+        // theta_scale and freq_scale should not change during the current token inference process,
+        // so we can directly use == here instead of comparing the absolute difference.
+        ctx.rope_cache.theta_scale != theta_scale ||
+        ctx.rope_cache.freq_scale != freq_scale) {
+
+        ctx.rope_cache.theta_scale_length = theta_scale_length;
+        ctx.rope_cache.theta_scale = theta_scale;
+        ctx.rope_cache.freq_scale = freq_scale;
+
+        if (ctx.rope_cache.theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
 
-        aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
                                     theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+
         float start = 0;
         float step = 1;
-        float stop = ne00 / 2;
-        float n_elements = ne00 / 2;
+        float stop = theta_scale_length;
+        float n_elements = theta_scale_length;
         aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
 
         // power
@@ -2328,35 +2327,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         if (freq_scale != 1) {
             aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
         }
-
-        // freq_factors
-        if (src2) {
-            aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
-                src2->data, ggml_cann_type_mapping(src2->type),
-                ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-            aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
-            ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
-        }
-        // release
-        ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
+        ggml_cann_release_resources(ctx, acl_theta_scale);
+    } else {
+        // use cache
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
+                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
     }
 
-    // init sin_repeat && cos_repeat, one token just init in 0 layer
-    if (position_length > ctx.max_prompt_length) {
-        ctx.max_prompt_length = position_length;
-        int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
-        if (ctx.rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
-            ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
-        }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+    ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
+    // freq_factors
+    if (src2) {
+        freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float_t));
+        void* freq_fac_res_ptr = freq_fac_res_allocator.get();
+        aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
+            src2->data, ggml_cann_type_mapping(src2->type),
+            ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
+            freq_fac_res_ptr, ACL_FLOAT, sizeof(float_t),
+            theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
+        std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
+        ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
     }
 
-    aclTensor* acl_theta_scale_tensor =
-        ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
-                                theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
@@ -2397,17 +2391,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
     }
 
-    int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+    int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float_t);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     // repeat
@@ -2449,6 +2443,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     // TODO: use ascendc
     // Only test with LLAMA model.
     ggml_tensor* src0 = dst->src[0];  // input
+    ggml_tensor* src1 = dst->src[1];
 
     // param
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2481,8 +2476,16 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
 
     const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 
+    // sin/cos tensor length.
+    int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
+    ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
+    void* sin_tensor_buffer = sin_tensor_allocator.get();
+    void* cos_tensor_buffer = cos_tensor_allocator.get();
+
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox);
+    aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
+                     theta_scale, freq_scale, attn_factor, is_neox);
 
     int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
@@ -2491,10 +2494,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_reshape_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
     aclTensor* acl_src = ggml_cann_create_tensor(src0);
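Note: the new context members referenced in these hunks (rms_norm_one_tensor_cache, rms_norm_zero_tensor_cache and ctx.rope_cache) are defined elsewhere in this change, outside the hunks shown. Below is a minimal sketch of what those members presumably look like, reconstructed only from the field accesses visible above (.cache, .size, .theta_scale_cache, .theta_scale_length, .theta_scale, .freq_scale). The struct names, default values and cleanup are assumptions for illustration, not the actual definitions from the CANN backend header.

```cpp
#include <acl/acl.h>
#include <cstdint>

// Sketch only: reconstructed from the member accesses in this diff.
// Constant-filled f32 device buffer (all ones / all zeros) reused by RMS norm.
struct ggml_cann_f32_tensor_cache_sketch {
    void*   cache = nullptr;   // device pointer obtained via aclrtMalloc
    int64_t size  = 0;         // number of f32 elements currently allocated
};

// Cached theta-scale arange for RoPE; rebuilt whenever the length,
// theta_scale or freq_scale that produced it changes.
struct ggml_cann_rope_cache_sketch {
    void*   theta_scale_cache  = nullptr;
    int64_t theta_scale_length = 0;
    float   theta_scale        = 0.0f;
    float   freq_scale         = 0.0f;

    ~ggml_cann_rope_cache_sketch() {
        // Free the device buffer when the backend context is destroyed.
        if (theta_scale_cache != nullptr) {
            aclrtFree(theta_scale_cache);
        }
    }
};
```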