Skip to content

Commit 22cdb5a

Browse files
Changed passing shapes/strides to kernel in copy_usm_ndarray_for_reshape
Insteads of invoking 4 copy kernels, it is more expedient to pack them on the host and use single copy kernel to reduce kernel submission overhead.wq
1 parent f16e932 commit 22cdb5a

File tree

1 file changed

+47
-50
lines changed

1 file changed

+47
-50
lines changed

dpctl/tensor/libtensor/source/tensor_py.cpp

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -932,99 +932,96 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
932932

933933
auto fn = copy_for_reshape_generic_dispatch_vector[type_id];
934934

935+
// packed_shape_strides = [src_shape, src_strides, dst_shape, dst_strides]
935936
py::ssize_t *packed_shapes_strides =
936937
sycl::malloc_device<py::ssize_t>(2 * (src_nd + dst_nd), exec_q);
937938

938939
if (packed_shapes_strides == nullptr) {
939940
throw std::runtime_error("Unabled to allocate device memory");
940941
}
941942

942-
sycl::event src_shape_copy_ev =
943-
exec_q.copy<py::ssize_t>(src_shape, packed_shapes_strides, src_nd);
944-
sycl::event dst_shape_copy_ev = exec_q.copy<py::ssize_t>(
945-
dst_shape, packed_shapes_strides + 2 * src_nd, dst_nd);
943+
using shT = std::vector<py::ssize_t>;
944+
std::shared_ptr<shT> packed_host_shapes_strides_shp =
945+
std::make_shared<shT>(2 * (src_nd + dst_nd));
946+
947+
std::copy(src_shape, src_shape + src_nd,
948+
packed_host_shapes_strides_shp->begin());
949+
std::copy(dst_shape, dst_shape + dst_nd,
950+
packed_host_shapes_strides_shp->begin() + 2 * src_nd);
946951

947952
const py::ssize_t *src_strides = src.get_strides_raw();
948-
sycl::event src_strides_copy_ev;
949953
if (src_strides == nullptr) {
950-
using shT = std::vector<py::ssize_t>;
951954
int src_flags = src.get_flags();
952-
std::shared_ptr<shT> contig_src_strides_shp;
953955
if (src_flags & USM_ARRAY_C_CONTIGUOUS) {
954-
contig_src_strides_shp =
955-
std::make_shared<shT>(c_contiguous_strides(src_nd, src_shape));
956+
const shT &src_contig_strides =
957+
c_contiguous_strides(src_nd, src_shape);
958+
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
959+
packed_host_shapes_strides_shp->begin() + src_nd);
956960
}
957961
else if (src_flags & USM_ARRAY_F_CONTIGUOUS) {
958-
contig_src_strides_shp =
959-
std::make_shared<shT>(f_contiguous_strides(src_nd, src_shape));
962+
const shT &src_contig_strides =
963+
c_contiguous_strides(src_nd, src_shape);
964+
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
965+
packed_host_shapes_strides_shp->begin() + src_nd);
960966
}
961967
else {
962-
sycl::event::wait({src_shape_copy_ev, dst_shape_copy_ev});
963968
sycl::free(packed_shapes_strides, exec_q);
964969
throw std::runtime_error(
965970
"Invalid src array encountered: in copy_for_reshape function");
966971
}
967-
src_strides_copy_ev =
968-
exec_q.copy<py::ssize_t>(contig_src_strides_shp->data(),
969-
packed_shapes_strides + src_nd, src_nd);
970-
exec_q.submit([&](sycl::handler &cgh) {
971-
cgh.depends_on(src_strides_copy_ev);
972-
cgh.host_task([contig_src_strides_shp]() {
973-
// Capturing shared pointer ensure it is freed after its data
974-
// are copied into packed USM vector
975-
});
976-
});
977972
}
978973
else {
979-
src_strides_copy_ev = exec_q.copy<py::ssize_t>(
980-
src_strides, packed_shapes_strides + src_nd, src_nd);
974+
std::copy(src_strides, src_strides + src_nd,
975+
packed_host_shapes_strides_shp->begin() + src_nd);
981976
}
982977

983978
const py::ssize_t *dst_strides = dst.get_strides_raw();
984-
sycl::event dst_strides_copy_ev;
985979
if (dst_strides == nullptr) {
986-
using shT = std::vector<py::ssize_t>;
987980
int dst_flags = dst.get_flags();
988-
std::shared_ptr<shT> contig_dst_strides_shp;
989981
if (dst_flags & USM_ARRAY_C_CONTIGUOUS) {
990-
contig_dst_strides_shp =
991-
std::make_shared<shT>(c_contiguous_strides(dst_nd, dst_shape));
982+
const shT &dst_contig_strides =
983+
c_contiguous_strides(dst_nd, dst_shape);
984+
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
985+
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
986+
dst_nd);
992987
}
993988
else if (dst_flags & USM_ARRAY_F_CONTIGUOUS) {
994-
contig_dst_strides_shp =
995-
std::make_shared<shT>(f_contiguous_strides(dst_nd, dst_shape));
989+
const shT &dst_contig_strides =
990+
f_contiguous_strides(dst_nd, dst_shape);
991+
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
992+
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
993+
dst_nd);
996994
}
997995
else {
998-
sycl::event::wait(
999-
{src_shape_copy_ev, dst_shape_copy_ev, src_strides_copy_ev});
1000996
sycl::free(packed_shapes_strides, exec_q);
1001997
throw std::runtime_error(
1002998
"Invalid dst array encountered: in copy_for_reshape function");
1003999
}
1004-
dst_strides_copy_ev = exec_q.copy<py::ssize_t>(
1005-
contig_dst_strides_shp->data(),
1006-
packed_shapes_strides + 2 * src_nd + dst_nd, dst_nd);
1007-
exec_q.submit([&](sycl::handler &cgh) {
1008-
cgh.depends_on(dst_strides_copy_ev);
1009-
cgh.host_task([contig_dst_strides_shp]() {
1010-
// Capturing shared pointer ensure it is freed after its data
1011-
// are copied into packed USM vector
1012-
});
1013-
});
10141000
}
10151001
else {
1016-
dst_strides_copy_ev = exec_q.copy<py::ssize_t>(
1017-
dst_strides, packed_shapes_strides + 2 * src_nd + dst_nd, dst_nd);
1002+
std::copy(dst_strides, dst_strides + dst_nd,
1003+
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
1004+
dst_nd);
10181005
}
10191006

1007+
// copy packed shapes and strides from host to devices
1008+
sycl::event packed_shape_strides_copy_ev = exec_q.copy<py::ssize_t>(
1009+
packed_host_shapes_strides_shp->data(), packed_shapes_strides,
1010+
packed_host_shapes_strides_shp->size());
1011+
exec_q.submit([&](sycl::handler &cgh) {
1012+
cgh.depends_on(packed_shape_strides_copy_ev);
1013+
cgh.host_task([packed_host_shapes_strides_shp] {
1014+
// Capturing shared pointer ensures that the underlying vector is
1015+
// not destroyed until after its data are copied into packed USM
1016+
// vector
1017+
});
1018+
});
1019+
10201020
char *src_data = src.get_data();
10211021
char *dst_data = dst.get_data();
10221022

1023-
std::vector<sycl::event> all_deps(depends.size() + 4);
1024-
all_deps.push_back(src_shape_copy_ev);
1025-
all_deps.push_back(dst_shape_copy_ev);
1026-
all_deps.push_back(src_strides_copy_ev);
1027-
all_deps.push_back(dst_strides_copy_ev);
1023+
std::vector<sycl::event> all_deps(depends.size() + 1);
1024+
all_deps.push_back(packed_shape_strides_copy_ev);
10281025
all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));
10291026

10301027
sycl::event copy_for_reshape_event =

0 commit comments

Comments
 (0)