@@ -753,69 +753,55 @@ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
 void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];
 
-    aclTensor* acl_src = ggml_cann_create_tensor(src0);
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     if (ggml_are_same_shape(src0, dst)) {
+        aclTensor* acl_src = ggml_cann_create_tensor(src0);
+        aclTensor* acl_dst = ggml_cann_create_tensor(dst);
         if (dst->type == src0->type) {
             cann_copy(ctx, acl_src, acl_dst);
         } else {
             aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
         }
+        ggml_cann_release_resources(ctx, acl_src, acl_dst);
     } else {
-        if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
-            if (dst->type == src0->type) {
-                size_t cpy_size = ggml_nbytes(dst);
-                ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
-                return;
-            } else {
-                ggml_cann_pool_alloc src_buffer_allocator(
-                    ctx.pool(),
-                    ggml_nelements(dst) * ggml_type_size(dst->type));
-                void* src_trans_buffer = src_buffer_allocator.get();
-                size_t src_trans_nb[GGML_MAX_DIMS];
-                src_trans_nb[0] = ggml_type_size(dst->type);
-                for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
-                }
-                aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                    src_trans_buffer, ggml_cann_type_mapping(dst->type),
-                    ggml_type_size(dst->type), src0->ne, src_trans_nb,
-                    GGML_MAX_DIMS);
-
-                aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
-                size_t cpy_size = ggml_nbytes(dst);
-                ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
-                ggml_cann_release_resources(ctx, src_trans_tensor);
-                return;
-            }
-        } else if (ggml_is_contiguous(dst)) {
-            ggml_cann_pool_alloc src_buffer_allocator(
-                ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type));
-            void* src_trans_buffer = src_buffer_allocator.get();
+        void* src_trans_buffer = src0->data;
+        ggml_cann_pool_alloc src_buffer_allocator;
+        if (!ggml_is_contiguous(src0)) {
+            aclTensor* acl_src = ggml_cann_create_tensor(src0);
+            src_buffer_allocator.alloc(ctx.pool(),
+                ggml_nelements(src0) * ggml_type_size(src0->type));
+            src_trans_buffer = src_buffer_allocator.get();
             size_t src_trans_nb[GGML_MAX_DIMS];
-            src_trans_nb[0] = ggml_type_size(dst->type);
+            src_trans_nb[0] = ggml_type_size(src0->type);
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
             }
             aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                src_trans_buffer, ggml_cann_type_mapping(dst->type),
-                ggml_type_size(dst->type), src0->ne, src_trans_nb,
+                src_trans_buffer, ggml_cann_type_mapping(src0->type),
+                ggml_type_size(src0->type), src0->ne, src_trans_nb,
                 GGML_MAX_DIMS);
+            cann_copy(ctx, acl_src, src_trans_tensor);
+            ggml_cann_release_resources(ctx, acl_src, src_trans_tensor);
+        }
 
-            aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+        size_t src_reshape_nb[GGML_MAX_DIMS];
+        src_reshape_nb[0] = ggml_type_size(src0->type);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            src_reshape_nb[i] = src_reshape_nb[i - 1] * dst->ne[i - 1];
+        }
 
-            size_t cpy_size = ggml_nbytes(dst);
-            ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE);
-            ggml_cann_release_resources(ctx, src_trans_tensor);
-            return;
+        aclTensor* trans_acl_src = ggml_cann_create_tensor(src_trans_buffer,
+            ggml_cann_type_mapping(src0->type),ggml_type_size(src0->type),
+            dst->ne, src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
+        aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+        if (dst->type == src0->type) {
+            cann_copy(ctx, trans_acl_src, acl_dst);
         } else {
-            GGML_ABORT("Unsupport dst is not contiguous.");
+            aclnn_cast(ctx, trans_acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
         }
+        ggml_cann_release_resources(ctx, trans_acl_src, acl_dst);
     }
-    ggml_cann_release_resources(ctx, acl_src, acl_dst);
+    return;
 }
 
 /**
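
The new else-branch follows a single staging pattern: if src0 is not contiguous, it is first gathered into a contiguous scratch buffer with src0's own type and strides, that buffer is then viewed with dst's shape via src_reshape_nb, and finally copied or cast into dst. Below is a minimal host-side sketch of that staging-then-reshape idea, assuming plain float buffers in place of aclTensor; the helper copy_strided_to_contiguous is hypothetical and is not part of ggml or the CANN backend.

// Host-side sketch only; does not use the ACL/CANN or ggml APIs.
#include <cstddef>
#include <cstdio>
#include <vector>

// Gather a strided (non-contiguous) 2-D source into a contiguous buffer,
// analogous to the cann_copy into src_trans_buffer in the diff above.
static std::vector<float> copy_strided_to_contiguous(const float* src,
                                                     size_t ne0, size_t ne1,
                                                     size_t row_stride) {
    std::vector<float> staging(ne0 * ne1);
    for (size_t i1 = 0; i1 < ne1; i1++) {
        for (size_t i0 = 0; i0 < ne0; i0++) {
            staging[i1 * ne0 + i0] = src[i1 * row_stride + i0];
        }
    }
    return staging;
}

int main() {
    // A 2x3 source view stored with a row stride of 4 elements (padding = -1),
    // i.e. a non-contiguous tensor.
    const float src[2 * 4] = {1, 2, 3, -1,
                              4, 5, 6, -1};
    std::vector<float> staging = copy_strided_to_contiguous(src, 3, 2, 4);

    // Once contiguous, the same bytes can be reinterpreted with dst's shape
    // (e.g. 3x2) and copied or cast element-wise, which is what the
    // trans_acl_src view built with src_reshape_nb does in the diff.
    for (size_t i = 0; i < staging.size(); i++) {
        std::printf("%g ", staging[i]);
    }
    std::printf("\n");  // prints: 1 2 3 4 5 6
    return 0;
}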