@@ -2154,86 +2154,129 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
21542154
21552155 GGML_TENSOR_BINARY_OP_LOCALS
21562156
2157- // theta_scale arange, [0,1,...,ne00/2 - 1]
21582157 int64_t theta_scale_length = ne00 / 2 ;
2159- ggml_cann_pool_alloc theta_scale_allocator (ctx.pool (),
2160- theta_scale_length * sizeof (float_t ));
2161- void * theta_scale_buffer = theta_scale_allocator.get ();
21622158 int64_t theta_scale_ne[] = {theta_scale_length, 1 , 1 , 1 };
21632159 size_t theta_scale_nb[] = {sizeof (float_t ), sizeof (float_t ), sizeof (float_t ),
21642160 theta_scale_length * sizeof (float_t )};
21652161
2166- aclTensor* acl_theta_scale_tensor =
2167- ggml_cann_create_tensor (theta_scale_buffer, ACL_FLOAT, sizeof (float_t ),
2168- theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2169- float start = 0 ;
2170- float step = 1 ;
2171- float stop = ne00 / 2 ;
2172- float n_elements = ne00 / 2 ;
2173- aclnn_arange (ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2174-
2175- // power
2176- aclScalar* acl_theta_scale = aclCreateScalar (&theta_scale, aclDataType::ACL_FLOAT);
2177- GGML_CANN_CALL_ACLNN_OP (ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2178- acl_theta_scale_tensor);
2179-
2180- // freq_scale
2181- if (freq_scale != 1 ) {
2182- aclnn_muls (ctx, acl_theta_scale_tensor, freq_scale, nullptr , true );
2183- }
2184-
2185- // freq_factors
2186- if (src2) {
2187- aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor (
2188- src2->data , ggml_cann_type_mapping (src2->type ),
2189- ggml_type_size (src2->type ), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2190- aclnn_div (ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
2191- ggml_cann_release_resources (ctx, acl_freq_factors_tensor);
2192- }
2193-
2194- // position
21952162 GGML_ASSERT (src1->type == GGML_TYPE_I32);
21962163 int64_t position_length = src1->ne [0 ];
21972164 int64_t position_ne[] = {1 , 1 , position_length, 1 };
21982165 size_t position_nb[] = {sizeof (int32_t ), sizeof (int32_t ), sizeof (int32_t ),
21992166 sizeof (int32_t ) * position_length};
2200- aclTensor* acl_position_tensor = ggml_cann_create_tensor (
2201- src1->data , ggml_cann_type_mapping (src1->type ),
2202- ggml_type_size (src1->type ), position_ne, position_nb, GGML_MAX_DIMS);
2203-
2204- // power * position
2205- int64_t theta_length = theta_scale_length * position_length;
2206- ggml_cann_pool_alloc theta_allocator (ctx.pool (),
2207- theta_length * sizeof (float_t ));
2208- void * theta_buffer = theta_allocator.get ();
2167+
22092168 int64_t theta_ne[] = {theta_scale_length, 1 , position_length, 1 };
22102169 size_t theta_nb[GGML_MAX_DIMS];
22112170 theta_nb[0 ] = sizeof (float_t );
22122171 for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
22132172 theta_nb[i] = theta_nb[i - 1 ] * theta_ne[i - 1 ];
22142173 }
2215- aclTensor* acl_theta_tensor =
2216- ggml_cann_create_tensor (theta_buffer, ACL_FLOAT, sizeof (float_t ),
2217- theta_ne, theta_nb, GGML_MAX_DIMS);
2218- aclnn_mul (ctx, acl_position_tensor, acl_theta_scale_tensor,
2219- acl_theta_tensor);
2220-
2221- // sin/cos
2222- ggml_cann_pool_alloc sin_allocator (ctx.pool (),
2223- theta_length * sizeof (float_t ));
2224- void * sin_buffer = sin_allocator.get ();
2225- aclTensor* acl_sin_tensor = ggml_cann_create_tensor (
2226- sin_buffer, ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2227- GGML_MAX_DIMS, ACL_FORMAT_ND);
2228- aclnn_sin (ctx, acl_theta_tensor, acl_sin_tensor);
22292174
2230- ggml_cann_pool_alloc cos_allocator (ctx.pool (),
2231- theta_length * sizeof (float_t ));
2232- void * cos_buffer = cos_allocator.get ();
2175+ bool is_q = (std::strncmp (dst->name , " Qcur-" , 5 ) == 0 );
2176+ bool is_k = (std::strncmp (dst->name , " Kcur-" , 5 ) == 0 );
2177+
2178+ // used for accuracy testing
2179+ bool is_attention = is_q || is_k;
2180+
2181+ if (ctx.init_ptr == nullptr || !is_attention) {
2182+ // theta_scale arange, [0,1,...,ne00/2 - 1]
2183+ if (ctx.init_ptr != nullptr ){
2184+ ACL_CHECK (aclrtFree (ctx.init_ptr ));
2185+ }
2186+ ACL_CHECK (aclrtMalloc (&ctx.init_ptr , theta_scale_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2187+
2188+ aclTensor* acl_theta_scale_tensor =
2189+ ggml_cann_create_tensor (ctx.init_ptr , ACL_FLOAT, sizeof (float_t ),
2190+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2191+ float start = 0 ;
2192+ float step = 1 ;
2193+ float stop = ne00 / 2 ;
2194+ float n_elements = ne00 / 2 ;
2195+ aclnn_arange (ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
2196+
2197+ // power
2198+ aclScalar* acl_theta_scale = aclCreateScalar (&theta_scale, aclDataType::ACL_FLOAT);
2199+ GGML_CANN_CALL_ACLNN_OP (ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
2200+ acl_theta_scale_tensor);
2201+
2202+ // freq_scale
2203+ if (freq_scale != 1 ) {
2204+ aclnn_muls (ctx, acl_theta_scale_tensor, freq_scale, nullptr , true );
2205+ }
2206+
2207+ // freq_factors
2208+ if (src2) {
2209+ aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor (
2210+ src2->data , ggml_cann_type_mapping (src2->type ),
2211+ ggml_type_size (src2->type ), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2212+ aclnn_div (ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
2213+ ggml_cann_release_resources (ctx, acl_freq_factors_tensor);
2214+ }
2215+ // release
2216+ ggml_cann_release_resources (ctx, acl_theta_scale_tensor,acl_theta_scale);
2217+ }
2218+
2219+ if (ctx.sin_ptr == nullptr ) {
2220+ int64_t theta_length = theta_scale_length * ctx.max_prompt_length ;
2221+ ACL_CHECK (aclrtMalloc (&ctx.sin_ptr , theta_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2222+ ACL_CHECK (aclrtMalloc (&ctx.cos_ptr , theta_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2223+ }
2224+ if (position_length > ctx.max_prompt_length ) {
2225+ ctx.max_prompt_length = position_length;
2226+ int64_t theta_length = theta_scale_length * ctx.max_prompt_length ;
2227+ ACL_CHECK (aclrtFree (ctx.sin_ptr ));
2228+ ACL_CHECK (aclrtFree (ctx.cos_ptr ));
2229+ ACL_CHECK (aclrtMalloc (&ctx.sin_ptr , theta_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2230+ ACL_CHECK (aclrtMalloc (&ctx.cos_ptr , theta_length * sizeof (float_t ), ACL_MEM_MALLOC_HUGE_FIRST));
2231+ }
2232+
2233+ bool is_fisrt_layer = (std::strncmp (dst->name , " Qcur-0" , GGML_MAX_NAME) == 0 );
2234+
2235+ if (is_fisrt_layer || !is_attention) {
2236+
2237+ aclTensor* acl_theta_scale_tensor =
2238+ ggml_cann_create_tensor (ctx.init_ptr , ACL_FLOAT, sizeof (float_t ),
2239+ theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
2240+
2241+ // position
2242+ aclTensor* acl_position_tensor = ggml_cann_create_tensor (
2243+ src1->data , ggml_cann_type_mapping (src1->type ),
2244+ ggml_type_size (src1->type ), position_ne, position_nb, GGML_MAX_DIMS);
2245+
2246+ // power * position
2247+ int64_t theta_length = theta_scale_length * position_length;
2248+ ggml_cann_pool_alloc theta_allocator (ctx.pool (),
2249+ theta_length * sizeof (float_t ));
2250+ void * theta_buffer = theta_allocator.get ();
2251+
2252+ aclTensor* acl_theta_tensor =
2253+ ggml_cann_create_tensor (theta_buffer, ACL_FLOAT, sizeof (float_t ),
2254+ theta_ne, theta_nb, GGML_MAX_DIMS);
2255+ aclnn_mul (ctx, acl_position_tensor, acl_theta_scale_tensor,
2256+ acl_theta_tensor);
2257+
2258+ // sin/cos
2259+ aclTensor* acl_sin_tensor = ggml_cann_create_tensor (
2260+ ctx.sin_ptr , ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2261+ GGML_MAX_DIMS, ACL_FORMAT_ND);
2262+ aclnn_sin (ctx, acl_theta_tensor, acl_sin_tensor);
2263+
2264+ aclTensor* acl_cos_tensor = ggml_cann_create_tensor (
2265+ ctx.cos_ptr , ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2266+ GGML_MAX_DIMS, ACL_FORMAT_ND);
2267+ aclnn_cos (ctx, acl_theta_tensor, acl_cos_tensor);
2268+
2269+ // release
2270+ ggml_cann_release_resources (ctx, acl_theta_scale_tensor, acl_position_tensor,
2271+ acl_theta_tensor, acl_sin_tensor, acl_cos_tensor);
2272+ }
2273+
2274+ aclTensor* acl_sin_tensor = ggml_cann_create_tensor (
2275+ ctx.sin_ptr , ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2276+ GGML_MAX_DIMS, ACL_FORMAT_ND);
22332277 aclTensor* acl_cos_tensor = ggml_cann_create_tensor (
2234- cos_buffer, ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2235- GGML_MAX_DIMS, ACL_FORMAT_ND);
2236- aclnn_cos (ctx, acl_theta_tensor, acl_cos_tensor);
2278+ ctx.cos_ptr , ACL_FLOAT, sizeof (float_t ), theta_ne, theta_nb,
2279+ GGML_MAX_DIMS, ACL_FORMAT_ND);
22372280
22382281 // attn_factor
22392282 if (attn_factor != 1 ) {
@@ -2257,8 +2300,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
22572300 }
22582301
22592302 // release
2260- ggml_cann_release_resources (ctx, acl_theta_scale_tensor, acl_position_tensor,
2261- acl_theta_tensor, acl_sin_tensor, acl_cos_tensor, acl_theta_scale);
2303+ ggml_cann_release_resources (ctx, acl_sin_tensor, acl_cos_tensor);
22622304}
22632305
22642306#ifdef __cplusplus
0 commit comments