// ===----------------------------------------------------------------------===//
 #include <CL/sycl.hpp>
+#include <algorithm>
 #include <complex>
 #include <cstdint>
 #include <pybind11/complex.h>
@@ -663,12 +664,6 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
         }
     }
 
-    std::shared_ptr<shT> shp_shape = std::make_shared<shT>(simplified_shape);
-    std::shared_ptr<shT> shp_src_strides =
-        std::make_shared<shT>(simplified_src_strides);
-    std::shared_ptr<shT> shp_dst_strides =
-        std::make_shared<shT>(simplified_dst_strides);
-
     // Generic implementation
     auto copy_and_cast_fn =
         copy_and_cast_generic_dispatch_table[dst_type_id][src_type_id];
@@ -682,77 +677,50 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
         throw std::runtime_error("Unabled to allocate device memory");
     }
 
-    sycl::event copy_shape_ev =
-        exec_q.copy<py::ssize_t>(shp_shape->data(), shape_strides, nd);
+    // create host temporary for packed shape and strides managed by shared
+    // pointer
+    std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);
+    std::copy(simplified_shape.begin(), simplified_shape.end(),
+              shp_host_shape_strides->begin());
 
-    exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(copy_shape_ev);
-        cgh.host_task([shp_shape]() {
-            // increment shared pointer ref-count to keep it alive
-            // till copy operation completes;
-        });
-    });
-
-    sycl::event copy_src_strides_ev;
     if (src_strides == nullptr) {
-        std::shared_ptr<shT> shp_contig_src_strides =
-            std::make_shared<shT>((src_flags & USM_ARRAY_C_CONTIGUOUS)
-                                      ? c_contiguous_strides(nd, shape)
-                                      : f_contiguous_strides(nd, shape));
-        copy_src_strides_ev = exec_q.copy<py::ssize_t>(
-            shp_contig_src_strides->data(), shape_strides + nd, nd);
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(copy_src_strides_ev);
-            cgh.host_task([shp_contig_src_strides]() {
-                // increment shared pointer ref-count to keep it alive
-                // till copy operation completes;
-            });
-        });
+        const shT &src_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
+                                            ? c_contiguous_strides(nd, shape)
+                                            : f_contiguous_strides(nd, shape);
+        std::copy(src_contig_strides.begin(), src_contig_strides.end(),
+                  shp_host_shape_strides->begin() + nd);
     }
     else {
-        copy_src_strides_ev = exec_q.copy<py::ssize_t>(shp_src_strides->data(),
-                                                       shape_strides + nd, nd);
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(copy_src_strides_ev);
-            cgh.host_task([shp_src_strides]() {
-                // increment shared pointer ref-count to keep it alive
-                // till copy operation completes;
-            });
-        });
+        std::copy(simplified_src_strides.begin(), simplified_src_strides.end(),
+                  shp_host_shape_strides->begin() + nd);
     }
 
-    sycl::event copy_dst_strides_ev;
     if (dst_strides == nullptr) {
-        std::shared_ptr<shT> shp_contig_dst_strides =
-            std::make_shared<shT>((dst_flags & USM_ARRAY_C_CONTIGUOUS)
-                                      ? c_contiguous_strides(nd, shape)
-                                      : f_contiguous_strides(nd, shape));
-        copy_dst_strides_ev = exec_q.copy<py::ssize_t>(
-            shp_contig_dst_strides->data(), shape_strides + 2 * nd, nd);
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(copy_dst_strides_ev);
-            cgh.host_task([shp_contig_dst_strides]() {
-                // increment shared pointer ref-count to keep it alive
-                // till copy operation completes;
-            });
-        });
+        const shT &dst_contig_strides = (dst_flags & USM_ARRAY_C_CONTIGUOUS)
+                                            ? c_contiguous_strides(nd, shape)
+                                            : f_contiguous_strides(nd, shape);
+        std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
+                  shp_host_shape_strides->begin() + 2 * nd);
     }
     else {
-        copy_dst_strides_ev = exec_q.copy<py::ssize_t>(
-            shp_dst_strides->data(), shape_strides + 2 * nd, nd);
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(copy_dst_strides_ev);
-            cgh.host_task([shp_dst_strides]() {
-                // increment shared pointer ref-count to keep it alive
-                // till copy operation completes;
-            });
-        });
+        std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(),
+                  shp_host_shape_strides->begin() + 2 * nd);
     }
 
+    sycl::event copy_shape_ev = exec_q.copy<py::ssize_t>(
+        shp_host_shape_strides->data(), shape_strides, 3 * nd);
+
+    exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(copy_shape_ev);
+        cgh.host_task([shp_host_shape_strides]() {
+            // increment shared pointer ref-count to keep it alive
+            // till copy operation completes;
+        });
+    });
+
     sycl::event copy_and_cast_generic_ev = copy_and_cast_fn(
         exec_q, src_nelems, nd, shape_strides, src_data, src_offset, dst_data,
-        dst_offset, depends,
-        {copy_shape_ev, copy_src_strides_ev, copy_dst_strides_ev});
+        dst_offset, depends, {copy_shape_ev});
 
     // async free of shape_strides temporary
     auto ctx = exec_q.get_context();
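
For reference, here is a minimal self-contained sketch (not dpctl code) of the pattern this change introduces: pack the shape and both stride arrays into a single shared_ptr-owned host vector, copy it to device USM with one queue::copy, and keep the host vector alive by a host_task that depends on the copy event. The helper name pack_shape_strides, the ssize_vec alias, and the main driver are assumptions made for this example; only the SYCL 2020 calls (queue::copy, handler::host_task, malloc_device) mirror those used in the diff.

// Illustrative sketch only; names other than the SYCL API are assumptions.
#include <CL/sycl.hpp>

#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>

using ssize_vec = std::vector<std::int64_t>;

sycl::event pack_shape_strides(sycl::queue &q,
                               std::int64_t *dev_shape_strides, // 3 * nd slots
                               const ssize_vec &shape,
                               const ssize_vec &src_strides,
                               const ssize_vec &dst_strides)
{
    const std::size_t nd = shape.size();

    // one host temporary holds shape, then src strides, then dst strides
    auto host_buf = std::make_shared<ssize_vec>(3 * nd);
    std::copy(shape.begin(), shape.end(), host_buf->begin());
    std::copy(src_strides.begin(), src_strides.end(), host_buf->begin() + nd);
    std::copy(dst_strides.begin(), dst_strides.end(),
              host_buf->begin() + 2 * nd);

    // single asynchronous host->device copy instead of three separate ones
    sycl::event copy_ev =
        q.copy<std::int64_t>(host_buf->data(), dev_shape_strides, 3 * nd);

    // the host_task captures the shared_ptr by value, so the host vector
    // stays alive at least until the copy has completed
    q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(copy_ev);
        cgh.host_task([host_buf]() { /* keep host_buf alive */ });
    });

    return copy_ev;
}

int main()
{
    sycl::queue q;
    const std::size_t nd = 2;
    std::int64_t *dev = sycl::malloc_device<std::int64_t>(3 * nd, q);

    sycl::event ev = pack_shape_strides(q, dev, {4, 5}, {5, 1}, {1, 4});
    // a kernel would normally list ev among its dependencies; here just wait
    q.wait();

    sycl::free(dev, q);
    return 0;
}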