@@ -2506,7 +2506,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
25062506
25072507 // ggml_mode = 0 --> aclnn_model = 1
25082508 int64_t acl_mode = mode == 0 ? 1 : mode;
2509-
2509+
25102510 switch (src0->type ) {
25112511 case GGML_TYPE_F32: {
25122512 GGML_CANN_CALL_ACLNN_OP (RotaryPositionEmbedding, acl_src, acl_cos_reshape_tensor,
@@ -2520,28 +2520,27 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
25202520 ggml_cann_pool_alloc dst_trans_allocator (
25212521 ctx.pool (), ggml_nelements (dst) * sizeof (float ));
25222522 void * dst_trans_buffer = dst_trans_allocator.get ();
2523-
2523+
25242524 size_t src_trans_nb[GGML_MAX_DIMS];
25252525 src_trans_nb[0 ] = sizeof (float );
25262526 for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
25272527 src_trans_nb[i] = src_trans_nb[i - 1 ] * src0->ne [i - 1 ];
25282528 }
2529-
2529+
25302530 aclTensor* acl_src_trans_tensor = ggml_cann_create_tensor (
25312531 src_trans_buffer, ACL_FLOAT, sizeof (float ), src0->ne , src_trans_nb,
25322532 GGML_MAX_DIMS);
2533-
25342533 aclTensor* acl_dst_trans_tensor = ggml_cann_create_tensor (
25352534 dst_trans_buffer, ACL_FLOAT, sizeof (float ), dst->ne , src_trans_nb,
25362535 GGML_MAX_DIMS);
2537-
2536+
25382537 aclnn_cast (ctx, acl_src, acl_src_trans_tensor, ACL_FLOAT);
2539-
2538+
25402539 GGML_CANN_CALL_ACLNN_OP (RotaryPositionEmbedding, acl_src_trans_tensor, acl_cos_reshape_tensor,
25412540 acl_sin_reshape_tensor, acl_mode, acl_dst_trans_tensor);
2542-
2541+
25432542 aclnn_cast (ctx, acl_dst_trans_tensor, acl_dst, ACL_FLOAT16);
2544-
2543+
25452544 ACL_CHECK (aclDestroyTensor (acl_src_trans_tensor));
25462545 ACL_CHECK (aclDestroyTensor (acl_dst_trans_tensor));
25472546 break ;
0 commit comments