@@ -964,8 +964,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_gamma = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_one_cache,
-        ctx.f32_one_cache_element,
+        &ctx.rms_norm_one_tensor_cache.cache,
+        ctx.rms_norm_one_tensor_cache.size,
         src->ne,
         acl_gamma_nb,
         1,  // dims
@@ -980,8 +980,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     }
     aclTensor* acl_rstd = get_f32_cache_acl_tensor(
         ctx,
-        &ctx.f32_zero_cache,
-        ctx.f32_zero_cache_element,
+        &ctx.rms_norm_zero_tensor_cache.cache,
+        ctx.rms_norm_zero_tensor_cache.size,
         src->ne,
         acl_rstd_nb,
         GGML_MAX_DIMS,
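
A note on the new field names: the two hunks above only require that each cache object expose a .cache device-buffer pointer and a .size element count, replacing the former f32_one_cache / f32_one_cache_element pairs. Below is a minimal sketch of such a grouped cache on ggml_backend_cann_context, assuming only what these call sites use; the struct name and the defaults are hypothetical, not taken from the PR:

    // Hypothetical grouping of the old f32 constant caches (sketch, not the PR's actual declaration).
    struct ggml_cann_f32_tensor_cache {
        void*   cache = nullptr;  // device buffer holding the cached f32 constants (1.0f or 0.0f)
        int64_t size  = 0;        // number of f32 elements currently allocated in the buffer
    };

    // On the backend context, as the call sites above imply:
    //   ggml_cann_f32_tensor_cache rms_norm_one_tensor_cache;   // ones, used as gamma
    //   ggml_cann_f32_tensor_cache rms_norm_zero_tensor_cache;  // zeros, used as rstd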
@@ -2249,7 +2249,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
  * 6. Expand sin/cos values by repeat or repeat_interleave depending
  *    on whether @param is_neox is enabled.
  * 7. Store the computed values into persistent buffers
- *    (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
+ *    (ctx.rope_cache.sin_cache / ctx.rope_cache.cos_cache).
  *
  * @param ctx         The CANN backend context, holding memory pool,
  *                    stream, and persistent buffers for rope init/cache.
@@ -2266,25 +2266,30 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
                              float attn_factor, bool is_neox) {
     // int sin/cos cache, cache has different repeat method depond on
     // @param.is_neox
-    bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
-    bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
-
-    // used for accuracy testing
-    bool is_attention = is_q || is_k;
-
-    // just compute in first layer in attention
-    bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
-    if (is_attention && !is_fisrt_layer) {
-        return;
-    }

     ggml_tensor* src0 = dst->src[0];  // input
     ggml_tensor* src1 = dst->src[1];  // position
     ggml_tensor* src2 = dst->src[2];  // freq_factors

-    GGML_TENSOR_BINARY_OP_LOCALS
+    // get first layer in current device.
+    int layer = 0;
+    const char* dash = std::strchr(dst->name, '-');
+    if (dash) {
+        layer = std::strtol(dash + 1, nullptr, 10);
+    }
+
+    // remember the first layer.
+    if (ctx.rope_cache.first_layer == -1)
+        ctx.rope_cache.first_layer = layer;

-    int64_t theta_scale_length = ne00 / 2;
+    // only init cache when freq_factors is not null or first layer.
+    // dash == nullptr means we are in test-backend-ops
+    if (dash != nullptr && src2 == nullptr && layer != ctx.rope_cache.first_layer) {
+        // use cache.
+        return;
+    }
+
+    int64_t theta_scale_length = src0->ne[0] / 2;
     int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
     size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
                                theta_scale_length * sizeof(float_t)};
@@ -2302,21 +2307,24 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
     }

+    // theta_scale arange, [0,1,...,ne00/2 - 1]
+    aclTensor* acl_theta_scale_tensor = nullptr;
     // init theta scale, just one time
-    if (ctx.rope_init_ptr == nullptr || !is_attention) {
-        // theta_scale arange, [0,1,...,ne00/2 - 1]
-        if (ctx.rope_init_ptr != nullptr){
-            ACL_CHECK(aclrtFree(ctx.rope_init_ptr));
+    // dash == nullptr means we are in test-backend-ops
+    if (ctx.rope_cache.theta_scale_cache == nullptr || dash == nullptr) {
+        if (ctx.rope_cache.theta_scale_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
+                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);

-        aclTensor* acl_theta_scale_tensor =
-            ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
-                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
         float start = 0;
         float step = 1;
-        float stop = ne00 / 2;
-        float n_elements = ne00 / 2;
+        float stop = src0->ne[0] / 2;
+        float n_elements = src0->ne[0] / 2;
         aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);

         // power
@@ -2328,35 +2336,37 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         if (freq_scale != 1) {
             aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
         }
+        ggml_cann_release_resources(ctx, acl_theta_scale);
+    } else {
+        // use cache
+        acl_theta_scale_tensor =
+            ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
+                                    theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+    }

-        // freq_factors
-        if (src2) {
-            aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
-                src2->data, ggml_cann_type_mapping(src2->type),
-                ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-            aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
-            ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
-        }
-        // release
-        ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_theta_scale);
+    // freq_factors
+    if (src2) {
+        aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
+            src2->data, ggml_cann_type_mapping(src2->type),
+            ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+        aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
+        ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
     }

     // init sin_repeat && cos_repeat, one token just init in 0 layer
-    if (position_length > ctx.max_prompt_length) {
-        ctx.max_prompt_length = position_length;
-        int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
-        if (ctx.rope_sin_ptr != nullptr) {
-            ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
-            ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
+    if (position_length > ctx.rope_cache.position_length) {
+        ctx.rope_cache.position_length = position_length;
+        if (ctx.rope_cache.sin_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
+        }
+        if (ctx.rope_cache.cos_cache != nullptr) {
+            ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
-        ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        int64_t repeat_theta_length = theta_scale_length * position_length * 2;
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
     }

-    aclTensor* acl_theta_scale_tensor =
-        ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
-                                theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
-
     // position
     aclTensor* acl_position_tensor = ggml_cann_create_tensor(
         src1->data, ggml_cann_type_mapping(src1->type),
@@ -2397,17 +2407,17 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
         aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
     }

-    int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
+    int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
     size_t sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float_t);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
         sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
     }
     aclTensor* acl_sin_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
     aclTensor* acl_cos_repeat_tensor =
-        ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t),
+        ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float_t),
                                 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);

     // repeat
@@ -2491,10 +2501,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
24912501 sin_reshape_nb[i] = sin_reshape_nb[i - 1 ] * sin_reshape_ne[i - 1 ];
24922502 }
24932503 aclTensor* acl_sin_reshape_tensor =
2494- ggml_cann_create_tensor (ctx.rope_sin_ptr , ACL_FLOAT, sizeof (float_t ),
2504+ ggml_cann_create_tensor (ctx.rope_cache . sin_cache , ACL_FLOAT, sizeof (float_t ),
24952505 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
24962506 aclTensor* acl_cos_reshape_tensor =
2497- ggml_cann_create_tensor (ctx.rope_cos_ptr , ACL_FLOAT, sizeof (float_t ),
2507+ ggml_cann_create_tensor (ctx.rope_cache . cos_cache , ACL_FLOAT, sizeof (float_t ),
24982508 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
24992509
25002510 aclTensor* acl_src = ggml_cann_create_tensor (src0);
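
Taken together, the rope hunks reference a new ctx.rope_cache aggregate with at least theta_scale_cache, sin_cache, cos_cache, position_length, and first_layer members (those names are read directly from the diff). A minimal sketch of a declaration that would satisfy this code follows; the struct name and the default values are assumptions, not the PR's actual definition:

    // Hypothetical layout of ctx.rope_cache (sketch inferred from the call sites above).
    struct ggml_cann_rope_cache {
        void*   theta_scale_cache = nullptr;  // arange/power/freq_scale result, theta_scale_length floats
        void*   sin_cache         = nullptr;  // repeated sin values, persistent across layers
        void*   cos_cache         = nullptr;  // repeated cos values, persistent across layers
        int64_t position_length   = 0;        // longest position length cached so far
        int     first_layer       = -1;       // -1 until the first RoPE layer on this device is seen
    };

With such a layout, a normal model run recomputes sin/cos only for the first layer seen on the device (or whenever freq_factors is present), and every later layer simply wraps sin_cache / cos_cache in aclTensors, as ggml_cann_rope does in the last hunk.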