Skip to content

Commit 832a981

Browse files
committed
Implemented advanced indexing kernels
- Kernels for _take, _put - Python API functions for take, put
1 parent c2d7928 commit 832a981

File tree

8 files changed

+1862
-15
lines changed

8 files changed

+1862
-15
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pybind11_add_module(${python_module_name} MODULE
2323
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_numpy_ndarray_into_usm_ndarray.cpp
2424
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/copy_for_reshape.cpp
2525
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/linear_sequences.cpp
26+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/advanced_indexing.cpp
2627
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/eye_ctor.cpp
2728
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/full_ctor.cpp
2829
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/triul_ctor.cpp

dpctl/tensor/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
)
5959
from dpctl.tensor._device import Device
6060
from dpctl.tensor._dlpack import from_dlpack
61+
from dpctl.tensor._indexing_functions import put, take
6162
from dpctl.tensor._manipulation_functions import (
6263
broadcast_arrays,
6364
broadcast_to,
@@ -112,6 +113,8 @@
112113
"expand_dims",
113114
"permute_dims",
114115
"squeeze",
116+
"take",
117+
"put",
115118
"from_numpy",
116119
"to_numpy",
117120
"asnumpy",

dpctl/tensor/_copy_utils.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
16+
import operator
17+
1618
import numpy as np
19+
from numpy.core.numeric import normalize_axis_index
1720

1821
import dpctl
1922
import dpctl.memory as dpm
@@ -449,14 +452,25 @@ def _mock_take_multi_index(ary, inds, p):
449452
raise IndexError(
450453
"arrays used as indices must be of integer (or boolean) type"
451454
)
452-
ary_np = dpt.asnumpy(ary)
453-
ind_np = (slice(None),) * p + tuple(dpt.asnumpy(ind) for ind in inds)
454-
res_np = ary_np[ind_np]
455+
inds = dpt.broadcast_arrays(*inds)
456+
ary_ndim = ary.ndim
457+
if ary_ndim > 0:
458+
p = operator.index(p)
459+
p = normalize_axis_index(p, ary_ndim)
460+
461+
res_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
462+
else:
463+
res_shape = inds[0].shape
455464
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
456465
res = dpt.empty(
457-
res_np.shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
466+
res_shape, dtype=ary.dtype, usm_type=res_usm_type, sycl_queue=exec_q
458467
)
459-
res[...] = res_np
468+
469+
hev, _ = ti._take(
470+
src=ary, ind=inds, dst=res, axis_start=p, mode=0, sycl_queue=exec_q
471+
)
472+
hev.wait()
473+
460474
return res
461475

462476

@@ -492,7 +506,7 @@ def _mock_place(ary, ary_mask, p, vals):
492506

493507

494508
def _mock_put_multi_index(ary, inds, p, vals):
495-
if isinstance(vals, dpt.ums_ndarray):
509+
if isinstance(vals, dpt.usm_ndarray):
496510
queues_ = [ary.sycl_queue, vals.sycl_queue]
497511
usm_types_ = [ary.usm_type, vals.usm_type]
498512
else:
@@ -522,14 +536,27 @@ def _mock_put_multi_index(ary, inds, p, vals):
522536
raise IndexError(
523537
"arrays used as indices must be of integer (or boolean) type"
524538
)
525-
ary_np = dpt.asnumpy(ary)
526-
if isinstance(vals, dpt.usm_ndarray) or hasattr(
527-
vals, "__sycl_usm_array_interface__"
528-
):
529-
vals_np = dpt.asnumpy(vals)
539+
540+
inds = dpt.broadcast_arrays(*inds)
541+
ary_ndim = ary.ndim
542+
if ary_ndim > 0:
543+
p = operator.index(p)
544+
p = normalize_axis_index(p, ary_ndim)
545+
vals_shape = ary.shape[:p] + inds[0].shape + ary.shape[p + len(inds) :]
530546
else:
531-
vals_np = vals
532-
ind_np = (slice(None),) * p + tuple(dpt.asnumpy(ind) for ind in inds)
533-
ary_np[ind_np] = vals_np
534-
ary[...] = ary_np
547+
vals_shape = inds[0].shape
548+
549+
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
550+
if not isinstance(vals, dpt.usm_ndarray):
551+
vals = dpt.asarray(
552+
vals, ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
553+
)
554+
555+
vals = dpt.broadcast_to(vals, vals_shape)
556+
557+
hev, _ = ti._put(
558+
dst=ary, ind=inds, val=vals, axis_start=p, mode=0, sycl_queue=exec_q
559+
)
560+
hev.wait()
561+
535562
return

dpctl/tensor/_indexing_functions.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2022 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import operator
18+
19+
import numpy as np
20+
from numpy.core.numeric import normalize_axis_index
21+
22+
import dpctl
23+
import dpctl.tensor as dpt
24+
from dpctl.tensor._tensor_impl import _put, _take
25+
26+
27+
def take(x, indices, /, *, axis=None, mode="clip"):
28+
if not isinstance(x, dpt.usm_ndarray):
29+
raise TypeError(
30+
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
31+
)
32+
33+
if not isinstance(indices, list) and not isinstance(indices, tuple):
34+
indices = (indices,)
35+
36+
queues_ = [
37+
x.sycl_queue,
38+
]
39+
usm_types_ = [
40+
x.usm_type,
41+
]
42+
43+
for i in indices:
44+
if not isinstance(i, dpt.usm_ndarray):
45+
raise TypeError(
46+
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
47+
type(i)
48+
)
49+
)
50+
if not np.issubdtype(i.dtype, np.integer):
51+
raise TypeError(
52+
"`indices` expected integer data type, got `{}`".format(i.dtype)
53+
)
54+
queues_.append(i.sycl_queue)
55+
usm_types_.append(i.usm_type)
56+
exec_q = dpctl.utils.get_execution_queue(queues_)
57+
if exec_q is None:
58+
raise dpctl.utils.ExecutionPlacementError(
59+
"Can not automatically determine where to allocate the "
60+
"result or performance execution. "
61+
"Use `usm_ndarray.to_device` method to migrate data to "
62+
"be associated with the same queue."
63+
)
64+
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
65+
66+
modes = {"clip": 0, "wrap": 1}
67+
try:
68+
mode = modes[mode]
69+
except KeyError:
70+
raise ValueError("`mode` must be `clip` or `wrap`.")
71+
72+
x_ndim = x.ndim
73+
if axis is None:
74+
if x_ndim > 1:
75+
raise ValueError(
76+
"`axis` cannot be `None` for array of dimension `{}`".format(
77+
x_ndim
78+
)
79+
)
80+
axis = 0
81+
82+
indices = dpt.broadcast_arrays(*indices)
83+
if x_ndim > 0:
84+
axis = operator.index(axis)
85+
axis = normalize_axis_index(axis, x_ndim)
86+
res_shape = (
87+
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]
88+
)
89+
else:
90+
res_shape = indices[0].shape
91+
92+
res = dpt.empty(
93+
res_shape, dtype=x.dtype, usm_type=res_usm_type, sycl_queue=exec_q
94+
)
95+
96+
hev, _ = _take(x, indices, res, axis, mode, sycl_queue=exec_q)
97+
hev.wait()
98+
99+
return res
100+
101+
102+
def put(x, indices, vals, /, *, axis=None, mode="clip"):
103+
if not isinstance(x, dpt.usm_ndarray):
104+
raise TypeError(
105+
"Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x))
106+
)
107+
queues_ = [
108+
x.sycl_queue,
109+
]
110+
usm_types_ = [
111+
x.usm_type,
112+
]
113+
114+
if not isinstance(indices, list) and not isinstance(indices, tuple):
115+
indices = (indices,)
116+
117+
for i in indices:
118+
if not isinstance(i, dpt.usm_ndarray):
119+
raise TypeError(
120+
"`indices` expected `dpt.usm_ndarray`, got `{}`.".format(
121+
type(i)
122+
)
123+
)
124+
if not np.issubdtype(i.dtype, np.integer):
125+
raise TypeError(
126+
"`indices` expected integer data type, got `{}`".format(i.dtype)
127+
)
128+
queues_.append(i.sycl_queue)
129+
usm_types_.append(i.usm_type)
130+
exec_q = dpctl.utils.get_execution_queue(queues_)
131+
if exec_q is None:
132+
raise dpctl.utils.ExecutionPlacementError(
133+
"Can not automatically determine where to allocate the "
134+
"result or performance execution. "
135+
"Use `usm_ndarray.to_device` method to migrate data to "
136+
"be associated with the same queue."
137+
)
138+
val_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
139+
140+
modes = {"clip": 0, "wrap": 1}
141+
try:
142+
mode = modes[mode]
143+
except KeyError:
144+
raise ValueError("`mode` must be `wrap`, or `clip`.")
145+
146+
# when axis is none, array is treated as 1D
147+
if axis is None:
148+
x = dpt.reshape(x, (x.size,), copy=False)
149+
axis = 0
150+
151+
indices = dpt.broadcast_arrays(*indices)
152+
x_ndim = x.ndim
153+
if x_ndim > 0:
154+
axis = operator.index(axis)
155+
axis = normalize_axis_index(axis, x_ndim)
156+
157+
val_shape = (
158+
x.shape[:axis] + indices[0].shape + x.shape[axis + len(indices) :]
159+
)
160+
else:
161+
val_shape = indices[0].shape
162+
163+
if not isinstance(vals, dpt.usm_ndarray):
164+
vals = dpt.asarray(
165+
vals, dtype=x.dtype, usm_type=val_usm_type, sycl_queue=exec_q
166+
)
167+
168+
vals = dpt.broadcast_to(vals, val_shape)
169+
170+
hev, _ = _put(x, indices, vals, axis, mode, sycl_queue=exec_q)
171+
hev.wait()

0 commit comments

Comments
 (0)