Skip to content

Commit 87ed0ad

Browse files
Added copy_for_reshape to contain implementation
1 parent 6f64ba0 commit 87ed0ad

File tree

4 files changed

+315
-207
lines changed

4 files changed

+315
-207
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pybind11_add_module(${python_module_name} MODULE
2020
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/tensor_py.cpp
2121
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/simplify_iteration_space.cpp
2222
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_and_cast_usm_to_usm.cpp
23+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
2324
)
2425
target_link_options(${python_module_name} PRIVATE -fsycl-device-code-split=per_kernel)
2526
target_include_directories(${python_module_name}
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===----------------------------------------------------------------------===//
24+
25+
#include <CL/sycl.hpp>
26+
#include <utility>
27+
#include <vector>
28+
29+
#include "copy_for_reshape.hpp"
30+
#include "dpctl4pybind11.hpp"
31+
#include "kernels/copy_and_cast.hpp"
32+
#include "utils/type_dispatch.hpp"
33+
#include <pybind11/pybind11.h>
34+
35+
namespace dpctl
36+
{
37+
namespace tensor
38+
{
39+
namespace py_internal
40+
{
41+
42+
namespace _ns = dpctl::tensor::detail;
43+
44+
using dpctl::tensor::kernels::copy_and_cast::copy_for_reshape_fn_ptr_t;
45+
using dpctl::utils::keep_args_alive;
46+
47+
// define static vector
48+
static copy_for_reshape_fn_ptr_t
49+
copy_for_reshape_generic_dispatch_vector[_ns::num_types];
50+
51+
/*
52+
* Copies src into dst (same data type) of different shapes by using flat
53+
* iterations.
54+
*
55+
* Equivalent to the following loop:
56+
*
57+
* for i for range(src.size):
58+
* dst[np.multi_index(i, dst.shape)] = src[np.multi_index(i, src.shape)]
59+
*/
60+
std::pair<sycl::event, sycl::event>
61+
copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
62+
dpctl::tensor::usm_ndarray dst,
63+
py::ssize_t shift,
64+
sycl::queue exec_q,
65+
const std::vector<sycl::event> &depends)
66+
{
67+
py::ssize_t src_nelems = src.get_size();
68+
py::ssize_t dst_nelems = dst.get_size();
69+
70+
// Must have the same number of elements
71+
if (src_nelems != dst_nelems) {
72+
throw py::value_error(
73+
"copy_usm_ndarray_for_reshape requires src and dst to "
74+
"have the same number of elements.");
75+
}
76+
77+
int src_typenum = src.get_typenum();
78+
int dst_typenum = dst.get_typenum();
79+
80+
// typenames must be the same
81+
if (src_typenum != dst_typenum) {
82+
throw py::value_error(
83+
"copy_usm_ndarray_for_reshape requires src and dst to "
84+
"have the same type.");
85+
}
86+
87+
if (src_nelems == 0) {
88+
return std::make_pair(sycl::event(), sycl::event());
89+
}
90+
91+
// destination must be ample enough to accomodate all elements
92+
{
93+
auto dst_offsets = dst.get_minmax_offsets();
94+
py::ssize_t range =
95+
static_cast<py::ssize_t>(dst_offsets.second - dst_offsets.first);
96+
if (range + 1 < src_nelems) {
97+
throw py::value_error(
98+
"Destination array can not accomodate all the "
99+
"elements of source array.");
100+
}
101+
}
102+
103+
// check same contexts
104+
sycl::queue src_q = src.get_queue();
105+
sycl::queue dst_q = dst.get_queue();
106+
107+
if (!dpctl::utils::queues_are_compatible(exec_q, {src_q, dst_q})) {
108+
throw py::value_error(
109+
"Execution queue is not compatible with allocation queues");
110+
}
111+
112+
if (src_nelems == 1) {
113+
// handle special case of 1-element array
114+
int src_elemsize = src.get_elemsize();
115+
char *src_data = src.get_data();
116+
char *dst_data = dst.get_data();
117+
sycl::event copy_ev =
118+
exec_q.copy<char>(src_data, dst_data, src_elemsize);
119+
return std::make_pair(keep_args_alive(exec_q, {src, dst}, {copy_ev}),
120+
copy_ev);
121+
}
122+
123+
// dimensions may be different
124+
int src_nd = src.get_ndim();
125+
int dst_nd = dst.get_ndim();
126+
127+
const py::ssize_t *src_shape = src.get_shape_raw();
128+
const py::ssize_t *dst_shape = dst.get_shape_raw();
129+
130+
auto array_types = dpctl::tensor::detail::usm_ndarray_types();
131+
int type_id = array_types.typenum_to_lookup_id(src_typenum);
132+
133+
auto fn = copy_for_reshape_generic_dispatch_vector[type_id];
134+
135+
// packed_shape_strides = [src_shape, src_strides, dst_shape, dst_strides]
136+
py::ssize_t *packed_shapes_strides =
137+
sycl::malloc_device<py::ssize_t>(2 * (src_nd + dst_nd), exec_q);
138+
139+
if (packed_shapes_strides == nullptr) {
140+
throw std::runtime_error("Unabled to allocate device memory");
141+
}
142+
143+
using usm_host_allocatorT =
144+
sycl::usm_allocator<py::ssize_t, sycl::usm::alloc::host>;
145+
using shT = std::vector<py::ssize_t, usm_host_allocatorT>;
146+
usm_host_allocatorT allocator(exec_q);
147+
std::shared_ptr<shT> packed_host_shapes_strides_shp =
148+
std::make_shared<shT>(2 * (src_nd + dst_nd), allocator);
149+
150+
std::copy(src_shape, src_shape + src_nd,
151+
packed_host_shapes_strides_shp->begin());
152+
std::copy(dst_shape, dst_shape + dst_nd,
153+
packed_host_shapes_strides_shp->begin() + 2 * src_nd);
154+
155+
const py::ssize_t *src_strides = src.get_strides_raw();
156+
if (src_strides == nullptr) {
157+
if (src.is_c_contiguous()) {
158+
const auto &src_contig_strides =
159+
c_contiguous_strides(src_nd, src_shape);
160+
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
161+
packed_host_shapes_strides_shp->begin() + src_nd);
162+
}
163+
else if (src.is_f_contiguous()) {
164+
const auto &src_contig_strides =
165+
f_contiguous_strides(src_nd, src_shape);
166+
std::copy(src_contig_strides.begin(), src_contig_strides.end(),
167+
packed_host_shapes_strides_shp->begin() + src_nd);
168+
}
169+
else {
170+
sycl::free(packed_shapes_strides, exec_q);
171+
throw std::runtime_error(
172+
"Invalid src array encountered: in copy_for_reshape function");
173+
}
174+
}
175+
else {
176+
std::copy(src_strides, src_strides + src_nd,
177+
packed_host_shapes_strides_shp->begin() + src_nd);
178+
}
179+
180+
const py::ssize_t *dst_strides = dst.get_strides_raw();
181+
if (dst_strides == nullptr) {
182+
if (dst.is_c_contiguous()) {
183+
const auto &dst_contig_strides =
184+
c_contiguous_strides(dst_nd, dst_shape);
185+
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
186+
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
187+
dst_nd);
188+
}
189+
else if (dst.is_f_contiguous()) {
190+
const auto &dst_contig_strides =
191+
f_contiguous_strides(dst_nd, dst_shape);
192+
std::copy(dst_contig_strides.begin(), dst_contig_strides.end(),
193+
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
194+
dst_nd);
195+
}
196+
else {
197+
sycl::free(packed_shapes_strides, exec_q);
198+
throw std::runtime_error(
199+
"Invalid dst array encountered: in copy_for_reshape function");
200+
}
201+
}
202+
else {
203+
std::copy(dst_strides, dst_strides + dst_nd,
204+
packed_host_shapes_strides_shp->begin() + 2 * src_nd +
205+
dst_nd);
206+
}
207+
208+
// copy packed shapes and strides from host to devices
209+
sycl::event packed_shape_strides_copy_ev = exec_q.copy<py::ssize_t>(
210+
packed_host_shapes_strides_shp->data(), packed_shapes_strides,
211+
packed_host_shapes_strides_shp->size());
212+
exec_q.submit([&](sycl::handler &cgh) {
213+
cgh.depends_on(packed_shape_strides_copy_ev);
214+
cgh.host_task([packed_host_shapes_strides_shp] {
215+
// Capturing shared pointer ensures that the underlying vector is
216+
// not destroyed until after its data are copied into packed USM
217+
// vector
218+
});
219+
});
220+
221+
char *src_data = src.get_data();
222+
char *dst_data = dst.get_data();
223+
224+
std::vector<sycl::event> all_deps(depends.size() + 1);
225+
all_deps.push_back(packed_shape_strides_copy_ev);
226+
all_deps.insert(std::end(all_deps), std::begin(depends), std::end(depends));
227+
228+
sycl::event copy_for_reshape_event =
229+
fn(exec_q, shift, src_nelems, src_nd, dst_nd, packed_shapes_strides,
230+
src_data, dst_data, all_deps);
231+
232+
exec_q.submit([&](sycl::handler &cgh) {
233+
cgh.depends_on(copy_for_reshape_event);
234+
auto ctx = exec_q.get_context();
235+
cgh.host_task([packed_shapes_strides, ctx]() {
236+
sycl::free(packed_shapes_strides, ctx);
237+
});
238+
});
239+
240+
return std::make_pair(
241+
keep_args_alive(exec_q, {src, dst}, {copy_for_reshape_event}),
242+
copy_for_reshape_event);
243+
}
244+
245+
void init_copy_for_reshape_dispatch_vectors(void)
246+
{
247+
using namespace dpctl::tensor::detail;
248+
using dpctl::tensor::kernels::copy_and_cast::CopyForReshapeGenericFactory;
249+
250+
DispatchVectorBuilder<copy_for_reshape_fn_ptr_t,
251+
CopyForReshapeGenericFactory, num_types>
252+
dvb;
253+
dvb.populate_dispatch_vector(copy_for_reshape_generic_dispatch_vector);
254+
}
255+
256+
} // namespace py_internal
257+
} // namespace tensor
258+
} // namespace dpctl
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//===----------- Implementation of _tensor_impl module ---------*-C++-*-/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2022 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines functions of dpctl.tensor._tensor_impl extensions
23+
//===----------------------------------------------------------------------===//
24+
25+
#pragma once
26+
#include <CL/sycl.hpp>
27+
#include <utility>
28+
#include <vector>
29+
30+
#include "dpctl4pybind11.hpp"
31+
#include <pybind11/pybind11.h>
32+
33+
namespace dpctl
34+
{
35+
namespace tensor
36+
{
37+
namespace py_internal
38+
{
39+
40+
extern std::pair<sycl::event, sycl::event>
41+
copy_usm_ndarray_for_reshape(dpctl::tensor::usm_ndarray src,
42+
dpctl::tensor::usm_ndarray dst,
43+
py::ssize_t shift,
44+
sycl::queue exec_q,
45+
const std::vector<sycl::event> &depends = {});
46+
47+
extern void init_copy_for_reshape_dispatch_vectors();
48+
49+
} // namespace py_internal
50+
} // namespace tensor
51+
} // namespace dpctl

0 commit comments

Comments
 (0)