@@ -145,23 +145,6 @@ static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
145145 GGML_CANN_CALL_ACLNN_OP (Cast, acl_src, cast_data_type, acl_dst);
146146}
147147
148- /* *
149- * @brief Casts the elements of a tensor to a specified data type using the CANN backend.
150- *
151- * @details This function performs a type conversion on the elements of the input tensor `acl_src`
152- * and stores the results in the destination tensor `acl_dst`. The conversion type is
153- * determined based on the `dst` tensor's data type.
154- *
155- * @param ctx The context for the CANN backend operations.
156- * @param acl_src The source tensor whose elements will be cast.
157- * @param acl_dst The destination tensor that will store the casted elements.
158- * @param dst The ggml tensor specifying the target data type.
159- */
160- static void aclnn_cast (ggml_backend_cann_context& ctx, aclTensor* acl_src,
161- aclTensor* acl_dst, ggml_tensor* dst) {
162- aclnn_cast (ctx, acl_src, acl_dst, ggml_cann_type_mapping (dst->type ));
163- }
164-
165148void ggml_cann_repeat (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
166149 ggml_tensor* src = dst->src [0 ];
167150 GGML_ASSERT (ggml_can_repeat (src, dst));
@@ -768,7 +751,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
768751 if (dst->type == src0->type ) {
769752 cann_copy (ctx, acl_src, acl_dst);
770753 } else {
771- aclnn_cast (ctx, acl_src, acl_dst, dst);
754+ aclnn_cast (ctx, acl_src, acl_dst, ggml_cann_type_mapping ( dst-> type ) );
772755 }
773756 } else {
774757 if (ggml_is_contiguous (src0) && ggml_is_contiguous (dst)) {
@@ -793,7 +776,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
793776 ggml_type_size (dst->type ), src0->ne , src_trans_nb,
794777 GGML_MAX_DIMS);
795778
796- aclnn_cast (ctx, acl_src, src_trans_tensor, dst);
779+ aclnn_cast (ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping ( dst-> type ) );
797780 size_t cpy_size = ggml_nbytes (dst);
798781 ACL_CHECK (aclrtMemcpyAsync (
799782 dst->data , cpy_size, src_trans_buffer, cpy_size,
@@ -815,7 +798,7 @@ void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
815798 ggml_type_size (dst->type ), src0->ne , src_trans_nb,
816799 GGML_MAX_DIMS);
817800
818- aclnn_cast (ctx, acl_src, src_trans_tensor, dst);
801+ aclnn_cast (ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping ( dst-> type ) );
819802
820803 size_t cpy_size = ggml_nbytes (dst);
821804 ACL_CHECK (aclrtMemcpyAsync (dst->data , cpy_size, src_trans_buffer,
@@ -1159,7 +1142,7 @@ void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
11591142 tmp_cast_buffer, ggml_cann_type_mapping (dst->type ),
11601143 ggml_type_size (dst->type ), tmp_im2col_ne, temp_cast_nb,
11611144 GGML_MAX_DIMS - 1 , ACL_FORMAT_ND);
1162- aclnn_cast (ctx, tmp_im2col_tensor, tmp_cast_tensor, dst);
1145+ aclnn_cast (ctx, tmp_im2col_tensor, tmp_cast_tensor, ggml_cann_type_mapping ( dst-> type ) );
11631146 }
11641147
11651148 // post-processing
@@ -1734,7 +1717,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
17341717 aclTensor* src_trans_tensor = ggml_cann_create_tensor (
17351718 src_trans_buffer, ACL_FLOAT, ggml_type_size (dst->type ),
17361719 src0->ne , src_trans_nb, GGML_MAX_DIMS);
1737- aclnn_cast (ctx, acl_src0, src_trans_tensor, dst);
1720+ aclnn_cast (ctx, acl_src0, src_trans_tensor, ggml_cann_type_mapping ( dst-> type ) );
17381721 aclnn_embedding_4d (ctx, src_trans_buffer, src0->ne ,
17391722 src_trans_nb, src1, dst);
17401723 ACL_CHECK (aclDestroyTensor (acl_src0));
@@ -2075,7 +2058,7 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
20752058 output_buffer, ACL_FLOAT16, output_elem_size, output_cast_ne,
20762059 output_cast_nb, GGML_MAX_DIMS);
20772060 aclTensor* acl_dst_tensor = ggml_cann_create_tensor (dst);
2078- aclnn_cast (ctx, acl_output_tensor, acl_dst_tensor, dst);
2061+ aclnn_cast (ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping ( dst-> type ) );
20792062
20802063 ACL_CHECK (aclDestroyTensor (acl_output_tensor));
20812064 ACL_CHECK (aclDestroyTensor (acl_dst_tensor));
@@ -2162,7 +2145,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
21622145
21632146 GGML_TENSOR_BINARY_OP_LOCALS
21642147
2165- // theta_scale arange, [0,1,...,ne0/2 ]
2148+ // theta_scale arange, [0,1,...,ne00/2 - 1 ]
21662149 int64_t theta_scale_length = ne00 / 2 ;
21672150 ggml_cann_pool_alloc theta_scale_allocator (ctx.pool (),
21682151 theta_scale_length * sizeof (float_t ));
@@ -2291,7 +2274,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
22912274 // TODO: use ascendc
22922275 // Only test with LLAMA model.
22932276 ggml_tensor* src0 = dst->src [0 ]; // input
2294- // ggml_tensor* src2 = dst->src[2]; // freq_factors, not used now.
22952277
22962278 // param
22972279 float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@@ -2345,7 +2327,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
23452327 ggml_cann_create_tensor (cos_buffer, ACL_FLOAT, sizeof (float_t ),
23462328 sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
23472329 aclnn_cache_init (ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
2348- theta_scale, freq_scale, attn_factor, is_neox);
2330+ theta_scale, freq_scale, attn_factor, is_neox);
23492331
23502332 aclTensor* acl_src = ggml_cann_create_tensor (src0);
23512333 aclTensor* acl_dst = ggml_cann_create_tensor (dst);
@@ -2522,46 +2504,52 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
25222504 return ;
25232505#endif
25242506
2525- // src0 == GGML_TYPE_F16
2526- // TODO: optimization this `if` code
2527- if (src0->type == GGML_TYPE_F16) {
2528- ggml_cann_pool_alloc sin_final_allocator (
2529- ctx.pool (), src0->ne [0 ] * src0->ne [2 ] * ggml_type_size (src0->type ));
2530- ggml_cann_pool_alloc cos_final_allocator (
2531- ctx.pool (), src0->ne [0 ] * src0->ne [2 ] * ggml_type_size (src0->type ));
2532- void * sin_final_buffer = sin_final_allocator.get ();
2533- void * cos_final_buffer = cos_final_allocator.get ();
2534-
2535- int64_t sin_final_ne[4 ] = {src0->ne [0 ], 1 , src0->ne [2 ], 1 };
2536- size_t sin_final_nb[GGML_MAX_DIMS];
2537- sin_final_nb[0 ] = ggml_type_size (src0->type );
2538- for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
2539- sin_final_nb[i] = sin_final_nb[i - 1 ] * sin_final_ne[i - 1 ];
2507+ // ggml_mode = 0 --> aclnn_model = 1
2508+ int64_t acl_mode = mode == 0 ? 1 : mode;
2509+
2510+ switch (src0->type ) {
2511+ case GGML_TYPE_F32: {
2512+ GGML_CANN_CALL_ACLNN_OP (RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor,
2513+ acl_sin_reshape_tensor, acl_mode, acl_dst);
2514+ break ;
25402515 }
2541- aclTensor* acl_sin_final_tensor = ggml_cann_create_tensor (
2542- sin_final_buffer, ggml_cann_type_mapping (src0->type ),
2543- ggml_type_size (src0->type ), sin_final_ne, sin_final_nb,
2544- GGML_MAX_DIMS);
2545- aclTensor* acl_cos_final_tensor = ggml_cann_create_tensor (
2546- cos_final_buffer, ggml_cann_type_mapping (src0->type ),
2547- ggml_type_size (src0->type ), sin_final_ne, sin_final_nb,
2548- GGML_MAX_DIMS);
2549-
2550- aclnn_cast (ctx, acl_sin_reshape_tensor, acl_sin_final_tensor, dst);
2551- aclnn_cast (ctx, acl_cos_reshape_tensor, acl_cos_final_tensor, dst);
2552- ACL_CHECK (aclDestroyTensor (acl_cos_reshape_tensor));
2553- ACL_CHECK (aclDestroyTensor (acl_sin_reshape_tensor));
2554- acl_sin_reshape_tensor = acl_sin_final_tensor;
2555- acl_cos_reshape_tensor = acl_cos_final_tensor;
2556- }
2557-
2558- int acl_mode = mode;
2559- if (mode == 0 ) {
2560- acl_mode = 1 ;
2516+ case GGML_TYPE_F16: {
2517+ ggml_cann_pool_alloc src_trans_allocator (
2518+ ctx.pool (), ggml_nelements (src0) * sizeof (float ));
2519+ void * src_trans_buffer = src_trans_allocator.get ();
2520+ ggml_cann_pool_alloc dst_trans_allocator (
2521+ ctx.pool (), ggml_nelements (dst) * sizeof (float ));
2522+ void * dst_trans_buffer = dst_trans_allocator.get ();
2523+
2524+ size_t src_trans_nb[GGML_MAX_DIMS];
2525+ src_trans_nb[0 ] = sizeof (float );
2526+ for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
2527+ src_trans_nb[i] = src_trans_nb[i - 1 ] * src0->ne [i - 1 ];
2528+ }
2529+
2530+ aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor (
2531+ src_trans_buffer, ACL_FLOAT, sizeof (float ), src0->ne , src_trans_nb,
2532+ GGML_MAX_DIMS);
2533+
2534+ aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor (
2535+ dst_trans_buffer, ACL_FLOAT, sizeof (float ), dst->ne , src_trans_nb,
2536+ GGML_MAX_DIMS);
2537+
2538+ aclnn_cast (ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT);
2539+
2540+ GGML_CANN_CALL_ACLNN_OP (RotaryPositionEmbedding, acl_src_trans_tensor, acl_cos_reshape_tensor,
2541+ acl_sin_reshape_tensor, acl_mode, acl_dst_trans_tensor);
2542+
2543+ aclnn_cast (ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16);
2544+
2545+ ACL_CHECK (aclDestroyTensor (acl_src_trans_tensor));
2546+ ACL_CHECK (aclDestroyTensor (acl_dst_trans_tensor));
2547+ break ;
2548+ }
2549+ default :
2550+ GGML_ABORT (" Unsupported tensor type for GGML_OP_ROPE" );
2551+ break ;
25612552 }
2562-
2563- GGML_CANN_CALL_ACLNN_OP (RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor,
2564- acl_sin_reshape_tensor, acl_mode, acl_dst);
25652553 ACL_CHECK (aclDestroyTensor (acl_src));
25662554 ACL_CHECK (aclDestroyTensor (acl_cos_reshape_tensor));
25672555 ACL_CHECK (aclDestroyTensor (acl_sin_reshape_tensor));
0 commit comments