@@ -753,69 +753,55 @@ static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
 void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src0 = dst->src[0];
 
-    aclTensor* acl_src = ggml_cann_create_tensor(src0);
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
     if (ggml_are_same_shape(src0, dst)) {
+        aclTensor* acl_src = ggml_cann_create_tensor(src0);
+        aclTensor* acl_dst = ggml_cann_create_tensor(dst);
         if (dst->type == src0->type) {
             cann_copy(ctx, acl_src, acl_dst);
         } else {
             aclnn_cast(ctx, acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
         }
+        ggml_cann_release_resources(ctx, acl_src, acl_dst);
     } else {
-        if (ggml_is_contiguous(src0) && ggml_is_contiguous(dst)) {
-            if (dst->type == src0->type) {
-                size_t cpy_size = ggml_nbytes(dst);
-                ggml_cann_async_memcpy(ctx, dst->data, src0->data, cpy_size,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
-                return;
-            } else {
-                ggml_cann_pool_alloc src_buffer_allocator(
-                    ctx.pool(),
-                    ggml_nelements(dst) * ggml_type_size(dst->type));
-                void* src_trans_buffer = src_buffer_allocator.get();
-                size_t src_trans_nb[GGML_MAX_DIMS];
-                src_trans_nb[0] = ggml_type_size(dst->type);
-                for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
-                }
-                aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                    src_trans_buffer, ggml_cann_type_mapping(dst->type),
-                    ggml_type_size(dst->type), src0->ne, src_trans_nb,
-                    GGML_MAX_DIMS);
-
-                aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
-                size_t cpy_size = ggml_nbytes(dst);
-                ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
-                                       ACL_MEMCPY_DEVICE_TO_DEVICE);
-                ggml_cann_release_resources(ctx, src_trans_tensor);
-                return;
-            }
-        } else if (ggml_is_contiguous(dst)) {
-            ggml_cann_pool_alloc src_buffer_allocator(
-                ctx.pool(), ggml_nelements(dst) * ggml_type_size(dst->type));
-            void* src_trans_buffer = src_buffer_allocator.get();
+        void* src_trans_buffer = src0->data;
+        ggml_cann_pool_alloc src_buffer_allocator;
+        if (!ggml_is_contiguous(src0)) {
+            aclTensor* acl_src = ggml_cann_create_tensor(src0);
+            src_buffer_allocator.alloc(ctx.pool(),
+                ggml_nelements(src0) * ggml_type_size(src0->type));
+            src_trans_buffer = src_buffer_allocator.get();
             size_t src_trans_nb[GGML_MAX_DIMS];
-            src_trans_nb[0] = ggml_type_size(dst->type);
+            src_trans_nb[0] = ggml_type_size(src0->type);
             for (int i = 1; i < GGML_MAX_DIMS; i++) {
                 src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
             }
             aclTensor* src_trans_tensor = ggml_cann_create_tensor(
-                src_trans_buffer, ggml_cann_type_mapping(dst->type),
-                ggml_type_size(dst->type), src0->ne, src_trans_nb,
+                src_trans_buffer, ggml_cann_type_mapping(src0->type),
+                ggml_type_size(src0->type), src0->ne, src_trans_nb,
                 GGML_MAX_DIMS);
+            cann_copy(ctx, acl_src, src_trans_tensor);
+            ggml_cann_release_resources(ctx, acl_src, src_trans_tensor);
+        }
 
-            aclnn_cast(ctx, acl_src, src_trans_tensor, ggml_cann_type_mapping(dst->type));
+        size_t src_reshape_nb[GGML_MAX_DIMS];
+        src_reshape_nb[0] = ggml_type_size(src0->type);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            src_reshape_nb[i] = src_reshape_nb[i - 1] * dst->ne[i - 1];
+        }
 
-            size_t cpy_size = ggml_nbytes(dst);
-            ggml_cann_async_memcpy(ctx, dst->data, src_trans_buffer, cpy_size,
-                                   ACL_MEMCPY_DEVICE_TO_DEVICE);
-            ggml_cann_release_resources(ctx, src_trans_tensor);
-            return;
+        aclTensor* trans_acl_src = ggml_cann_create_tensor(src_trans_buffer,
+            ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type),
+            dst->ne, src_reshape_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
+        aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+        if (dst->type == src0->type) {
+            cann_copy(ctx, trans_acl_src, acl_dst);
         } else {
-            GGML_ABORT("Unsupport dst is not contiguous.");
+            aclnn_cast(ctx, trans_acl_src, acl_dst, ggml_cann_type_mapping(dst->type));
        }
+        ggml_cann_release_resources(ctx, trans_acl_src, acl_dst);
     }
-    ggml_cann_release_resources(ctx, acl_src, acl_dst);
+    return;
 }
 
 /**
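The new control flow reduces to a general strategy: if the source view is non-contiguous, first pack it into a contiguous scratch buffer using the source's own strides; a contiguous buffer holding N elements can then be reinterpreted under any shape with the same N, so the destination's shape is applied as a free "reshape" before the final copy or cast. Below is a minimal host-side C++ sketch of that strategy, assuming no CANN device; `pack_contiguous` and the 2-D ne/nb layout are illustrative stand-ins, not part of the ggml CANN API.

```cpp
#include <cstdio>
#include <cstddef>
#include <vector>

// Pack a strided (possibly non-contiguous) 2-D float view into a
// row-major contiguous buffer. ne* are extents (innermost first) and
// nb* are strides in bytes, mirroring ggml's ne/nb convention.
static std::vector<float> pack_contiguous(const float* data,
                                          size_t ne0, size_t ne1,
                                          size_t nb0, size_t nb1) {
    std::vector<float> out(ne0 * ne1);
    const char* base = reinterpret_cast<const char*>(data);
    for (size_t i1 = 0; i1 < ne1; i1++) {
        for (size_t i0 = 0; i0 < ne0; i0++) {
            out[i1 * ne0 + i0] =
                *reinterpret_cast<const float*>(base + i0 * nb0 + i1 * nb1);
        }
    }
    return out;
}

int main() {
    // A 2x3 row-major matrix, viewed transposed (3 rows of 2) by swapping
    // strides; nb0 != sizeof(float) makes the view non-contiguous.
    float src[6] = {1, 2, 3, 4, 5, 6};
    std::vector<float> packed =
        pack_contiguous(src, /*ne0=*/2, /*ne1=*/3,
                        /*nb0=*/3 * sizeof(float), /*nb1=*/sizeof(float));

    // The packed buffer is contiguous, so viewing it with dst's shape is a
    // pure reinterpretation; the final loop doubles as the type cast
    // (float -> double), the analogue of the cann_copy/aclnn_cast branch.
    double dst[6];
    for (size_t i = 0; i < 6; i++) {
        dst[i] = static_cast<double>(packed[i]);
    }
    for (double v : dst) {
        printf("%.0f ", v);   // prints: 1 4 2 5 3 6
    }
    printf("\n");
    return 0;
}
```

The same two-phase shape mirrors the diff: the `!ggml_is_contiguous(src0)` block is the packing step (on device, via `cann_copy` into the pool-allocated scratch buffer), and `trans_acl_src` is the packed buffer reinterpreted with `dst->ne` and freshly computed contiguous strides before the copy-or-cast.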