
Commit f16e932

Improved transferring of shapes/strides to device for usm_ndarray copy_cast kernels
1 parent 0390cfe · commit f16e932
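
In short, the commit replaces three separate host-to-device copies of the shape, source strides, and destination strides (each with its own keep-alive host_task) with one packed host temporary and a single copy. Below is a minimal, self-contained sketch of that packing pattern, assuming a plain std::vector<std::int64_t> in place of dpctl's shT/py::ssize_t and a made-up helper name pack_shape_strides_to_device; it illustrates the approach and is not the dpctl implementation itself.

// Standalone sketch (not dpctl code) of the pattern the commit adopts:
// shape, src strides and dst strides are packed into one host buffer of
// length 3*nd, transferred with a single queue::copy, and the host buffer
// is kept alive by a host_task capturing the shared_ptr.
#include <CL/sycl.hpp>
#include <algorithm>
#include <cstdint>
#include <memory>
#include <vector>

using index_t = std::int64_t; // stand-in for py::ssize_t

// Hypothetical helper; dpctl does this inline in
// copy_usm_ndarray_into_usm_ndarray rather than in a separate function.
sycl::event pack_shape_strides_to_device(sycl::queue &q,
                                         int nd,
                                         const std::vector<index_t> &shape,
                                         const std::vector<index_t> &src_strides,
                                         const std::vector<index_t> &dst_strides,
                                         index_t *device_shape_strides)
{
    // one host temporary holds [shape | src_strides | dst_strides]
    auto host_packed = std::make_shared<std::vector<index_t>>(3 * nd);
    std::copy(shape.begin(), shape.end(), host_packed->begin());
    std::copy(src_strides.begin(), src_strides.end(), host_packed->begin() + nd);
    std::copy(dst_strides.begin(), dst_strides.end(),
              host_packed->begin() + 2 * nd);

    // single host-to-device transfer instead of three separate copies
    sycl::event copy_ev =
        q.copy<index_t>(host_packed->data(), device_shape_strides, 3 * nd);

    // host_task captures the shared_ptr so the host buffer outlives the copy
    q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(copy_ev);
        cgh.host_task([host_packed]() { /* keep-alive only */ });
    });
    return copy_ev;
}

int main()
{
    sycl::queue q;
    const int nd = 2;
    std::vector<index_t> shape{3, 4}, src_strides{4, 1}, dst_strides{4, 1};
    index_t *dev = sycl::malloc_device<index_t>(3 * nd, q);
    sycl::event ev =
        pack_shape_strides_to_device(q, nd, shape, src_strides, dst_strides, dev);
    q.wait(); // dpctl instead makes its kernels depend on the copy event
    sycl::free(dev, q);
    return 0;
}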

File tree

1 file changed: +32 -64 lines changed


dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 32 additions & 64 deletions
@@ -23,6 +23,7 @@
 //===----------------------------------------------------------------------===//

 #include <CL/sycl.hpp>
+#include <algorithm>
 #include <complex>
 #include <cstdint>
 #include <pybind11/complex.h>
@@ -663,12 +664,6 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
         }
     }

-    std::shared_ptr<shT> shp_shape = std::make_shared<shT>(simplified_shape);
-    std::shared_ptr<shT> shp_src_strides =
-        std::make_shared<shT>(simplified_src_strides);
-    std::shared_ptr<shT> shp_dst_strides =
-        std::make_shared<shT>(simplified_dst_strides);
-
     // Generic implementation
     auto copy_and_cast_fn =
         copy_and_cast_generic_dispatch_table[dst_type_id][src_type_id];
@@ -682,77 +677,50 @@ copy_usm_ndarray_into_usm_ndarray(dpctl::tensor::usm_ndarray src,
         throw std::runtime_error("Unabled to allocate device memory");
     }

-    sycl::event copy_shape_ev =
-        exec_q.copy<py::ssize_t>(shp_shape->data(), shape_strides, nd);
+    // create host temporary for packed shape and strides managed by shared
+    // pointer
+    std::shared_ptr<shT> shp_host_shape_strides = std::make_shared<shT>(3 * nd);
+    std::copy(simplified_shape.begin(), simplified_shape.end(),
+              shp_host_shape_strides->begin());

-    exec_q.submit([&](sycl::handler &cgh) {
-        cgh.depends_on(copy_shape_ev);
-        cgh.host_task([shp_shape]() {
-            // increment shared pointer ref-count to keep it alive
-            // till copy operation completes;
-        });
-    });
-
-    sycl::event copy_src_strides_ev;
     if (src_strides == nullptr) {
-        std::shared_ptr<shT> shp_contig_src_strides =
-            std::make_shared<shT>((src_flags & USM_ARRAY_C_CONTIGUOUS)
-                                      ? c_contiguous_strides(nd, shape)
-                                      : f_contiguous_strides(nd, shape));
-        copy_src_strides_ev = exec_q.copy<py::ssize_t>(
-            shp_contig_src_strides->data(), shape_strides + nd, nd);
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(copy_src_strides_ev);
-            cgh.host_task([shp_contig_src_strides]() {
-                // increment shared pointer ref-count to keep it alive
-                // till copy operation completes;
-            });
-        });
+        const shT &src_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
+                                            ? c_contiguous_strides(nd, shape)
+                                            : f_contiguous_strides(nd, shape);
+        std::copy(src_contig_strides.begin(), src_contig_strides.end(),
+                  shp_host_shape_strides->begin() + nd);
     }
     else {
-        copy_src_strides_ev = exec_q.copy<py::ssize_t>(shp_src_strides->data(),
-                                                       shape_strides + nd, nd);
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(copy_src_strides_ev);
-            cgh.host_task([shp_src_strides]() {
-                // increment shared pointer ref-count to keep it alive
-                // till copy operation completes;
-            });
-        });
+        std::copy(simplified_src_strides.begin(), simplified_src_strides.end(),
+                  shp_host_shape_strides->begin() + nd);
     }

-    sycl::event copy_dst_strides_ev;
     if (dst_strides == nullptr) {
-        std::shared_ptr<shT> shp_contig_dst_strides =
-            std::make_shared<shT>((dst_flags & USM_ARRAY_C_CONTIGUOUS)
-                                      ? c_contiguous_strides(nd, shape)
-                                      : f_contiguous_strides(nd, shape));
-        copy_dst_strides_ev = exec_q.copy<py::ssize_t>(
-            shp_contig_dst_strides->data(), shape_strides + 2 * nd, nd);
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(copy_dst_strides_ev);
-            cgh.host_task([shp_contig_dst_strides]() {
-                // increment shared pointer ref-count to keep it alive
-                // till copy operation completes;
-            });
-        });
+        const shT &dst_contig_strides = (src_flags & USM_ARRAY_C_CONTIGUOUS)
+                                            ? c_contiguous_strides(nd, shape)
+                                            : f_contiguous_strides(nd, shape);
+        std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
+                  shp_host_shape_strides->begin() + 2 * nd);
     }
     else {
-        copy_dst_strides_ev = exec_q.copy<py::ssize_t>(
-            shp_dst_strides->data(), shape_strides + 2 * nd, nd);
-        exec_q.submit([&](sycl::handler &cgh) {
-            cgh.depends_on(copy_dst_strides_ev);
-            cgh.host_task([shp_dst_strides]() {
-                // increment shared pointer ref-count to keep it alive
-                // till copy operation completes;
-            });
-        });
+        std::copy(simplified_dst_strides.begin(), simplified_dst_strides.end(),
+                  shp_host_shape_strides->begin() + nd);
     }

+    sycl::event copy_shape_ev = exec_q.copy<py::ssize_t>(
+        shp_host_shape_strides->data(), shape_strides, 3 * nd);
+
+    exec_q.submit([&](sycl::handler &cgh) {
+        cgh.depends_on(copy_shape_ev);
+        cgh.host_task([shp_host_shape_strides]() {
+            // increment shared pointer ref-count to keep it alive
+            // till copy operation completes;
+        });
+    });
+
     sycl::event copy_and_cast_generic_ev = copy_and_cast_fn(
         exec_q, src_nelems, nd, shape_strides, src_data, src_offset, dst_data,
-        dst_offset, depends,
-        {copy_shape_ev, copy_src_strides_ev, copy_dst_strides_ev});
+        dst_offset, depends, {copy_shape_ev});

     // async free of shape_strides temporary
     auto ctx = exec_q.get_context();
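
The trailing context lines ("// async free of shape_strides temporary", "auto ctx = exec_q.get_context();") refer to code just past the end of the hunk. As a hedged sketch of that kind of pattern, assuming the device temporary is released from a host_task ordered after the consuming kernel (the function name and signature here are illustrative, not dpctl's):

// Sketch of an "async free" of a device temporary: once the dependent
// event completes, a host_task frees the packed shape/strides allocation.
#include <CL/sycl.hpp>
#include <cstdint>

void async_free_device_temporary(sycl::queue &q,
                                 sycl::event dependent_ev,
                                 std::int64_t *shape_strides)
{
    sycl::context ctx = q.get_context();
    q.submit([&](sycl::handler &cgh) {
        cgh.depends_on(dependent_ev);
        // the free runs on the host only after the kernel that reads
        // shape_strides has finished
        cgh.host_task([ctx, shape_strides]() { sycl::free(shape_strides, ctx); });
    });
}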
