@@ -2154,87 +2154,128 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
21542154
21552155 GGML_TENSOR_BINARY_OP_LOCALS
21562156
2157- // theta_scale arange, [0,1,...,ne00/2 - 1]
21582157 int64_t theta_scale_length = ne00 / 2 ;
2159- ggml_cann_pool_alloc theta_scale_allocator (ctx.pool (),
2160- theta_scale_length * sizeof (float_t ));
2161- void * theta_scale_buffer = theta_scale_allocator.get ();
21622158 int64_t theta_scale_ne[] = {theta_scale_length, 1 , 1 , 1 };
21632159 size_t theta_scale_nb[] = {sizeof (float_t ), sizeof (float_t ), sizeof (float_t ),
21642160 theta_scale_length * sizeof (float_t )};
21652161
2166- aclTensor* acl_theta_scale_tensor =
2167- ggml_cann_create_tensor (theta_scale_buffer, ACL_FLOAT, sizeof (float_t ),
2168- theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2169- float start = 0 ;
2170- float step = 1 ;
2171- float stop = ne00 / 2 ;
2172- float n_elements = ne00 / 2 ;
2173- aclnn_arange (ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2174-
2175- // power
2176- aclScalar* acl_theta_scale = aclCreateScalar (&theta_scale, aclDataType::ACL_FLOAT);
2177- GGML_CANN_CALL_ACLNN_OP (ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2178- acl_theta_scale_tensor);
2179-
2180- // freq_scale
2181- if (freq_scale != 1 ) {
2182- aclnn_muls (ctx, acl_theta_scale_tensor, freq_scale, nullptr , true );
2183- }
2184-
2185- // freq_factors
2186- if (src2) {
2187- aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor (
2188- src2->data , ggml_cann_type_mapping (src2->type ),
2189- ggml_type_size (src2->type ), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2190- aclnn_div (ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
2191- ggml_cann_release_resources (ctx, acl_freq_factors_tensor);
2192- }
2193-
2194- // position
21952162 GGML_ASSERT (src1->type == GGML_TYPE_I32);
21962163 int64_t position_length = src1->ne [0 ];
21972164 int64_t position_ne[] = {1 , 1 , position_length, 1 };
21982165 size_t position_nb[] = {sizeof (int32_t ), sizeof (int32_t ), sizeof (int32_t ),
21992166 sizeof (int32_t ) * position_length};
2200- aclTensor* acl_position_tensor = ggml_cann_create_tensor (
2201- src1->data , ggml_cann_type_mapping (src1->type ),
2202- ggml_type_size (src1->type ), position_ne, position_nb, GGML_MAX_DIMS);
2203-
2204- // power * position
2205- int64_t theta_length = theta_scale_length * position_length;
2206- ggml_cann_pool_alloc theta_allocator (ctx.pool (),
2207- theta_length * sizeof (float_t ));
2208- void * theta_buffer = theta_allocator.get ();
2167+
22092168 int64_t theta_ne[] = {theta_scale_length, 1 , position_length, 1 };
22102169 size_t theta_nb[GGML_MAX_DIMS];
22112170 theta_nb[0 ] = sizeof (float_t );
22122171 for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
22132172 theta_nb[i] = theta_nb[i - 1 ] * theta_ne[i - 1 ];
22142173 }
2215- aclTensor* acl_theta_tensor =
2216- ggml_cann_create_tensor (theta_buffer, ACL_FLOAT, sizeof (float_t ),
2217- theta_ne, theta_nb, GGML_MAX_DIMS);
2218- aclnn_mul (ctx, acl_position_tensor, acl_theta_scale_tensor,
2219- acl_theta_tensor);
2220-
2221- // sin/cos
2222- ggml_cann_pool_alloc sin_allocator (ctx.pool (),
2223- theta_length * sizeof (float_t ));
2224- void * sin_buffer = sin_allocator.get ();
2225- aclTensor* acl_sin_tensor = ggml_cann_create_tensor (
2226- sin_buffer, ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2227- GGML_MAX_DIMS, ACL_FORMAT_ND);
2228- aclnn_sin (ctx, acl_theta_tensor, acl_sin_tensor);
22292174
2230- ggml_cann_pool_alloc cos_allocator (ctx.pool (),
2231- theta_length * sizeof (float_t ));
2232- void * cos_buffer = cos_allocator.get ();
2233- aclTensor* acl_cos_tensor = ggml_cann_create_tensor (
2234- cos_buffer, ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2235- GGML_MAX_DIMS, ACL_FORMAT_ND);
2236- aclnn_cos (ctx, acl_theta_tensor, acl_cos_tensor);
2175+ bool is_q = (std::strncmp (dst->name , " Qcur-" , 5 ) == 0 );
2176+ bool is_k = (std::strncmp (dst->name , " Kcur-" , 5 ) == 0 );
2177+ bool is_attention = is_q || is_k;
22372178
2179+ if (ctx.init_ptr == nullptr || !is_attention) {
2180+ // theta_scale arange, [0,1,...,ne00/2 - 1]
2181+ if (ctx.init_ptr != nullptr ){
2182+ ACL_CHECK (aclrtFree (ctx.init_ptr ));
2183+ }
2184+ ACL_CHECK (aclrtMalloc (&ctx.init_ptr ,theta_scale_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2185+
2186+ aclTensor* acl_theta_scale_tensor =
2187+ ggml_cann_create_tensor (ctx.init_ptr , ACL_FLOAT, sizeof (float_t ),
2188+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2189+ float start = 0 ;
2190+ float step = 1 ;
2191+ float stop = ne00 / 2 ;
2192+ float n_elements = ne00 / 2 ;
2193+ aclnn_arange (ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2194+
2195+ // power
2196+ aclScalar* acl_theta_scale = aclCreateScalar (&theta_scale, aclDataType::ACL_FLOAT);
2197+ GGML_CANN_CALL_ACLNN_OP (ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2198+ acl_theta_scale_tensor);
2199+
2200+ // freq_scale
2201+ if (freq_scale != 1 ) {
2202+ aclnn_muls (ctx, acl_theta_scale_tensor, freq_scale, nullptr , true );
2203+ }
2204+
2205+ // freq_factors
2206+ if (src2) {
2207+ aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor (
2208+ src2->data , ggml_cann_type_mapping (src2->type ),
2209+ ggml_type_size (src2->type ), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2210+ aclnn_div (ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
2211+ ggml_cann_release_resources (ctx, acl_freq_factors_tensor);
2212+ }
2213+ // release
2214+ ggml_cann_release_resources (ctx, acl_theta_scale_tensor,acl_theta_scale);
2215+ }
2216+
2217+ if (ctx.sin_ptr == nullptr ) {
2218+ int64_t theta_length = theta_scale_length * ctx.max_position_length ;
2219+ ACL_CHECK (aclrtMalloc (&ctx.sin_ptr , theta_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2220+ ACL_CHECK (aclrtMalloc (&ctx.cos_ptr , theta_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2221+ }
2222+ if (position_length > ctx.max_position_length ) {
2223+ ctx.max_position_length = position_length;
2224+ int64_t theta_length = theta_scale_length * ctx.max_position_length ;
2225+ ACL_CHECK (aclrtFree (ctx.sin_ptr ));
2226+ ACL_CHECK (aclrtFree (ctx.cos_ptr ));
2227+ ACL_CHECK (aclrtMalloc (&ctx.sin_ptr , theta_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2228+ ACL_CHECK (aclrtMalloc (&ctx.cos_ptr , theta_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2229+ }
2230+
2231+ bool is_fisrt_layer = (std::strncmp (dst->name , " Qcur-0" , GGML_MAX_NAME) == 0 );
2232+
2233+ if (is_fisrt_layer || !is_attention) {
2234+
2235+ aclTensor* acl_theta_scale_tensor =
2236+ ggml_cann_create_tensor (ctx.init_ptr , ACL_FLOAT, sizeof (float_t ),
2237+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2238+
2239+ // position
2240+ aclTensor* acl_position_tensor = ggml_cann_create_tensor (
2241+ src1->data , ggml_cann_type_mapping (src1->type ),
2242+ ggml_type_size (src1->type ), position_ne, position_nb, GGML_MAX_DIMS);
2243+
2244+ // power * position
2245+ int64_t theta_length = theta_scale_length * position_length;
2246+ ggml_cann_pool_alloc theta_allocator (ctx.pool (),
2247+ theta_length * sizeof (float_t ));
2248+ void * theta_buffer = theta_allocator.get ();
2249+
2250+ aclTensor* acl_theta_tensor =
2251+ ggml_cann_create_tensor (theta_buffer, ACL_FLOAT, sizeof (float_t ),
2252+ theta_ne, theta_nb, GGML_MAX_DIMS);
2253+ aclnn_mul (ctx, acl_position_tensor, acl_theta_scale_tensor,
2254+ acl_theta_tensor);
2255+
2256+ // sin/cos
2257+ aclTensor* acl_sin_tensor = ggml_cann_create_tensor (
2258+ ctx.sin_ptr , ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2259+ GGML_MAX_DIMS, ACL_FORMAT_ND);
2260+ aclnn_sin (ctx, acl_theta_tensor, acl_sin_tensor);
2261+
2262+ aclTensor* acl_cos_tensor = ggml_cann_create_tensor (
2263+ ctx.cos_ptr , ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2264+ GGML_MAX_DIMS, ACL_FORMAT_ND);
2265+ aclnn_cos (ctx, acl_theta_tensor, acl_cos_tensor);
2266+
2267+ // release
2268+ ggml_cann_release_resources (ctx, acl_theta_scale_tensor, acl_position_tensor,
2269+ acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
2270+ }
2271+
2272+ aclTensor* acl_sin_tensor = ggml_cann_create_tensor (
2273+ ctx.sin_ptr , ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2274+ GGML_MAX_DIMS, ACL_FORMAT_ND);
2275+ aclTensor* acl_cos_tensor = ggml_cann_create_tensor (
2276+ ctx.cos_ptr , ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2277+ GGML_MAX_DIMS, ACL_FORMAT_ND);
2278+
22382279 // attn_factor
22392280 if (attn_factor != 1 ) {
22402281 aclnn_muls (ctx, acl_sin_tensor, attn_factor, nullptr , true );
@@ -2257,8 +2298,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
22572298 }
22582299
22592300 // release
2260- ggml_cann_release_resources (ctx, acl_theta_scale_tensor, acl_position_tensor,
2261- acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale);
2301+ ggml_cann_release_resources (ctx, acl_sin_tensor, acl_cos_tensor);
22622302}
22632303
22642304#ifdef __cplusplus
0 commit comments