@@ -751,71 +751,46 @@ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
751751}
752752
753753void ggml_cann_dup (ggml_backend_cann_context& ctx, ggml_tensor* dst) {
754- ggml_tensor* src0 = dst->src [0 ];
755-
756- aclTensor* acl_src = ggml_cann_create_tensor (src0);
757- aclTensor* acl_dst = ggml_cann_create_tensor (dst);
758- if (ggml_are_same_shape (src0, dst)) {
759- if (dst->type == src0->type ) {
760- cann_copy (ctx, acl_src, acl_dst);
761- } else {
762- aclnn_cast (ctx, acl_src, acl_dst, ggml_cann_type_mapping (dst->type ));
754+ ggml_tensor* src0 = dst->src [0 ];
755+ void * src_trans_buffer = src0->data ;
756+ ggml_cann_pool_alloc src_buffer_allocator;
757+ if (!ggml_is_contiguous (src0)) {
758+ aclTensor* acl_src = ggml_cann_create_tensor (src0);
759+ src_buffer_allocator.alloc (ctx.pool (),
760+ ggml_nelements (src0) * ggml_type_size (src0->type ));
761+ src_trans_buffer = src_buffer_allocator.get ();
762+ size_t src_trans_nb[GGML_MAX_DIMS];
763+ src_trans_nb[0 ] = ggml_type_size (src0->type );
764+ for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
765+ src_trans_nb[i] = src_trans_nb[i - 1 ] * src0->ne [i - 1 ];
763766 }
764- } else {
765- if (ggml_is_contiguous (src0) && ggml_is_contiguous (dst)) {
766- if (dst->type == src0->type ) {
767- size_t cpy_size = ggml_nbytes (dst);
768- ggml_cann_async_memcpy (ctx, dst->data , src0->data , cpy_size,
769- ACL_MEMCPY_DEVICE_TO_DEVICE);
770- return ;
771- } else {
772- ggml_cann_pool_alloc src_buffer_allocator (
773- ctx.pool (),
774- ggml_nelements (dst) * ggml_type_size (dst->type ));
775- void * src_trans_buffer = src_buffer_allocator.get ();
776- size_t src_trans_nb[GGML_MAX_DIMS];
777- src_trans_nb[0 ] = ggml_type_size (dst->type );
778- for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
779- src_trans_nb[i] = src_trans_nb[i - 1 ] * src0->ne [i - 1 ];
780- }
781- aclTensor* src_trans_tensor = ggml_cann_create_tensor (
782- src_trans_buffer, ggml_cann_type_mapping (dst->type ),
783- ggml_type_size (dst->type ), src0->ne , src_trans_nb,
784- GGML_MAX_DIMS);
785-
786- aclnn_cast (ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping (dst->type ));
787- size_t cpy_size = ggml_nbytes (dst);
788- ggml_cann_async_memcpy (ctx, dst->data , src_trans_buffer, cpy_size,
789- ACL_MEMCPY_DEVICE_TO_DEVICE);
790- ggml_cann_release_resources (ctx, src_trans_tensor);
791- return ;
792- }
793- } else if (ggml_is_contiguous (dst)) {
794- ggml_cann_pool_alloc src_buffer_allocator (
795- ctx.pool (), ggml_nelements (dst) * ggml_type_size (dst->type ));
796- void * src_trans_buffer = src_buffer_allocator.get ();
797- size_t src_trans_nb[GGML_MAX_DIMS];
798- src_trans_nb[0 ] = ggml_type_size (dst->type );
799- for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
800- src_trans_nb[i] = src_trans_nb[i - 1 ] * src0->ne [i - 1 ];
801- }
802- aclTensor* src_trans_tensor = ggml_cann_create_tensor (
803- src_trans_buffer, ggml_cann_type_mapping (dst->type ),
804- ggml_type_size (dst->type ), src0->ne , src_trans_nb,
805- GGML_MAX_DIMS);
767+ aclTensor* src_trans_tensor = ggml_cann_create_tensor (
768+ src_trans_buffer, ggml_cann_type_mapping (src0->type ),
769+ ggml_type_size (src0->type ), src0->ne , src_trans_nb,
770+ GGML_MAX_DIMS);
771+ cann_copy (ctx, acl_src, src_trans_tensor);
772+ ggml_cann_release_resources (ctx, acl_src, src_trans_tensor);
773+ }
806774
807- aclnn_cast (ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping (dst->type ));
775+ size_t src_reshape_nb[GGML_MAX_DIMS];
776+ src_reshape_nb[0 ] = ggml_type_size (src0->type );
777+ for (int i = 1 ; i < GGML_MAX_DIMS; i++) {
778+ src_reshape_nb[i] = src_reshape_nb[i - 1 ] * dst->ne [i - 1 ];
779+ }
780+
781+ aclTensor* trans_acl_src = ggml_cann_create_tensor (src_trans_buffer,
782+ ggml_cann_type_mapping (src0->type ),ggml_type_size (src0->type ),
783+ dst->ne , src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
784+ aclTensor* acl_dst = ggml_cann_create_tensor (dst);
808785
809- size_t cpy_size = ggml_nbytes (dst);
810- ggml_cann_async_memcpy (ctx, dst->data , src_trans_buffer, cpy_size,
811- ACL_MEMCPY_DEVICE_TO_DEVICE);
812- ggml_cann_release_resources (ctx, src_trans_tensor);
813- return ;
814- } else {
815- GGML_ABORT (" Unsupport dst is not tontiguous." );
816- }
786+ if (dst->type == src0->type ) {
787+ cann_copy (ctx, trans_acl_src, acl_dst);
788+ } else {
789+ aclnn_cast (ctx, trans_acl_src, acl_dst, ggml_cann_type_mapping (dst->type ));
817790 }
818- ggml_cann_release_resources (ctx, acl_src, acl_dst);
791+
792+ ggml_cann_release_resources (ctx, trans_acl_src, acl_dst);
793+ return ;
819794}
820795
821796/* *
0 commit comments