@@ -932,99 +932,96 @@ copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
932
932
933
933
auto fn = copy_for_reshape_generic_dispatch_vector[type_id];
934
934
935
+ // packed_shape_strides = [src_shape, src_strides, dst_shape, dst_strides]
935
936
py::ssize_t *packed_shapes_strides =
936
937
sycl::malloc_device<py::ssize_t >(2 * (src_nd + dst_nd), exec_q);
937
938
938
939
if (packed_shapes_strides == nullptr ) {
939
940
throw std::runtime_error (" Unabled to allocate device memory" );
940
941
}
941
942
942
- sycl::event src_shape_copy_ev =
943
- exec_q.copy <py::ssize_t >(src_shape, packed_shapes_strides, src_nd);
944
- sycl::event dst_shape_copy_ev = exec_q.copy <py::ssize_t >(
945
- dst_shape, packed_shapes_strides + 2 * src_nd, dst_nd);
943
+ using shT = std::vector<py::ssize_t >;
944
+ std::shared_ptr<shT> packed_host_shapes_strides_shp =
945
+ std::make_shared<shT>(2 * (src_nd + dst_nd));
946
+
947
+ std::copy (src_shape, src_shape + src_nd,
948
+ packed_host_shapes_strides_shp->begin ());
949
+ std::copy (dst_shape, dst_shape + dst_nd,
950
+ packed_host_shapes_strides_shp->begin () + 2 * src_nd);
946
951
947
952
const py::ssize_t *src_strides = src.get_strides_raw ();
948
- sycl::event src_strides_copy_ev;
949
953
if (src_strides == nullptr ) {
950
- using shT = std::vector<py::ssize_t >;
951
954
int src_flags = src.get_flags ();
952
- std::shared_ptr<shT> contig_src_strides_shp;
953
955
if (src_flags & USM_ARRAY_C_CONTIGUOUS) {
954
- contig_src_strides_shp =
955
- std::make_shared<shT>(c_contiguous_strides (src_nd, src_shape));
956
+ const shT &src_contig_strides =
957
+ c_contiguous_strides (src_nd, src_shape);
958
+ std::copy (src_contig_strides.begin (), src_contig_strides.end (),
959
+ packed_host_shapes_strides_shp->begin () + src_nd);
956
960
}
957
961
else if (src_flags & USM_ARRAY_F_CONTIGUOUS) {
958
- contig_src_strides_shp =
959
- std::make_shared<shT>(f_contiguous_strides (src_nd, src_shape));
962
+ const shT &src_contig_strides =
963
+ c_contiguous_strides (src_nd, src_shape);
964
+ std::copy (src_contig_strides.begin (), src_contig_strides.end (),
965
+ packed_host_shapes_strides_shp->begin () + src_nd);
960
966
}
961
967
else {
962
- sycl::event::wait ({src_shape_copy_ev, dst_shape_copy_ev});
963
968
sycl::free (packed_shapes_strides, exec_q);
964
969
throw std::runtime_error (
965
970
" Invalid src array encountered: in copy_for_reshape function" );
966
971
}
967
- src_strides_copy_ev =
968
- exec_q.copy <py::ssize_t >(contig_src_strides_shp->data (),
969
- packed_shapes_strides + src_nd, src_nd);
970
- exec_q.submit ([&](sycl::handler &cgh) {
971
- cgh.depends_on (src_strides_copy_ev);
972
- cgh.host_task ([contig_src_strides_shp]() {
973
- // Capturing shared pointer ensure it is freed after its data
974
- // are copied into packed USM vector
975
- });
976
- });
977
972
}
978
973
else {
979
- src_strides_copy_ev = exec_q. copy <py:: ssize_t >(
980
- src_strides, packed_shapes_strides + src_nd, src_nd);
974
+ std::copy (src_strides, src_strides + src_nd,
975
+ packed_host_shapes_strides_shp-> begin () + src_nd);
981
976
}
982
977
983
978
const py::ssize_t *dst_strides = dst.get_strides_raw ();
984
- sycl::event dst_strides_copy_ev;
985
979
if (dst_strides == nullptr ) {
986
- using shT = std::vector<py::ssize_t >;
987
980
int dst_flags = dst.get_flags ();
988
- std::shared_ptr<shT> contig_dst_strides_shp;
989
981
if (dst_flags & USM_ARRAY_C_CONTIGUOUS) {
990
- contig_dst_strides_shp =
991
- std::make_shared<shT>(c_contiguous_strides (dst_nd, dst_shape));
982
+ const shT &dst_contig_strides =
983
+ c_contiguous_strides (dst_nd, dst_shape);
984
+ std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
985
+ packed_host_shapes_strides_shp->begin () + 2 * src_nd +
986
+ dst_nd);
992
987
}
993
988
else if (dst_flags & USM_ARRAY_F_CONTIGUOUS) {
994
- contig_dst_strides_shp =
995
- std::make_shared<shT>(f_contiguous_strides (dst_nd, dst_shape));
989
+ const shT &dst_contig_strides =
990
+ f_contiguous_strides (dst_nd, dst_shape);
991
+ std::copy (dst_contig_strides.begin (), dst_contig_strides.end (),
992
+ packed_host_shapes_strides_shp->begin () + 2 * src_nd +
993
+ dst_nd);
996
994
}
997
995
else {
998
- sycl::event::wait (
999
- {src_shape_copy_ev, dst_shape_copy_ev, src_strides_copy_ev});
1000
996
sycl::free (packed_shapes_strides, exec_q);
1001
997
throw std::runtime_error (
1002
998
" Invalid dst array encountered: in copy_for_reshape function" );
1003
999
}
1004
- dst_strides_copy_ev = exec_q.copy <py::ssize_t >(
1005
- contig_dst_strides_shp->data (),
1006
- packed_shapes_strides + 2 * src_nd + dst_nd, dst_nd);
1007
- exec_q.submit ([&](sycl::handler &cgh) {
1008
- cgh.depends_on (dst_strides_copy_ev);
1009
- cgh.host_task ([contig_dst_strides_shp]() {
1010
- // Capturing shared pointer ensure it is freed after its data
1011
- // are copied into packed USM vector
1012
- });
1013
- });
1014
1000
}
1015
1001
else {
1016
- dst_strides_copy_ev = exec_q.copy <py::ssize_t >(
1017
- dst_strides, packed_shapes_strides + 2 * src_nd + dst_nd, dst_nd);
1002
+ std::copy (dst_strides, dst_strides + dst_nd,
1003
+ packed_host_shapes_strides_shp->begin () + 2 * src_nd +
1004
+ dst_nd);
1018
1005
}
1019
1006
1007
+ // copy packed shapes and strides from host to devices
1008
+ sycl::event packed_shape_strides_copy_ev = exec_q.copy <py::ssize_t >(
1009
+ packed_host_shapes_strides_shp->data (), packed_shapes_strides,
1010
+ packed_host_shapes_strides_shp->size ());
1011
+ exec_q.submit ([&](sycl::handler &cgh) {
1012
+ cgh.depends_on (packed_shape_strides_copy_ev);
1013
+ cgh.host_task ([packed_host_shapes_strides_shp] {
1014
+ // Capturing shared pointer ensures that the underlying vector is
1015
+ // not destroyed until after its data are copied into packed USM
1016
+ // vector
1017
+ });
1018
+ });
1019
+
1020
1020
char *src_data = src.get_data ();
1021
1021
char *dst_data = dst.get_data ();
1022
1022
1023
- std::vector<sycl::event> all_deps (depends.size () + 4 );
1024
- all_deps.push_back (src_shape_copy_ev);
1025
- all_deps.push_back (dst_shape_copy_ev);
1026
- all_deps.push_back (src_strides_copy_ev);
1027
- all_deps.push_back (dst_strides_copy_ev);
1023
+ std::vector<sycl::event> all_deps (depends.size () + 1 );
1024
+ all_deps.push_back (packed_shape_strides_copy_ev);
1028
1025
all_deps.insert (std::end (all_deps), std::begin (depends), std::end (depends));
1029
1026
1030
1027
sycl::event copy_for_reshape_event =
0 commit comments