Skip to content

Commit d66c494

Browse files
Added _constants, and extended advanced indexing tests
1 parent 6589dfa commit d66c494

File tree

3 files changed

+130
-0
lines changed

3 files changed

+130
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@
8383
from dpctl.tensor._reshape import reshape
8484
from dpctl.tensor._usmarray import usm_ndarray
8585

86+
from ._constants import e, inf, nan, newaxis, pi
87+
8688
__all__ = [
8789
"Device",
8890
"usm_ndarray",
@@ -141,4 +143,9 @@
141143
"print_options",
142144
"usm_ndarray_repr",
143145
"usm_ndarray_str",
146+
"newaxis",
147+
"e",
148+
"pi",
149+
"nan",
150+
"inf",
144151
]

dpctl/tensor/_constants.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2023 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import numpy as np
18+
19+
newaxis = None
20+
21+
pi = np.pi
22+
e = np.e
23+
nan = np.nan
24+
inf = np.inf

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,102 @@ def test_basic_slice3():
5858
assert y.ndim == x.ndim
5959
assert y.shape == x.shape
6060
assert y.strides == x.strides
61+
62+
63+
def test_basic_slice4():
64+
q = get_queue_or_skip()
65+
n0, n1 = 5, 3
66+
x = dpt.empty((n0, n1), dtype="f4", sycl_queue=q)
67+
y = x[::-1]
68+
assert isinstance(y, dpt.usm_ndarray)
69+
assert y.shape == x.shape
70+
assert y.strides == (-x.strides[0], x.strides[1])
71+
actual_offset = y.__sycl_usm_array_interface__["offset"]
72+
assert actual_offset == (n0 - 1) * n1
73+
74+
75+
def test_basic_slice5():
76+
q = get_queue_or_skip()
77+
n0, n1 = 5, 3
78+
x = dpt.empty((n0, n1), dtype="c8", sycl_queue=q)
79+
y = x[:, ::-1]
80+
assert isinstance(y, dpt.usm_ndarray)
81+
assert y.shape == x.shape
82+
assert y.strides == (x.strides[0], -x.strides[1])
83+
actual_offset = y.__sycl_usm_array_interface__["offset"]
84+
assert actual_offset == (n1 - 1)
85+
86+
87+
def test_basic_slice6():
88+
q = get_queue_or_skip()
89+
i0, n0, n1 = 2, 4, 3
90+
x = dpt.empty((n0, n1), dtype="c8", sycl_queue=q)
91+
y = x[i0, ::-1]
92+
assert isinstance(y, dpt.usm_ndarray)
93+
assert y.shape == (x.shape[1],)
94+
assert y.strides == (-x.strides[1],)
95+
actual_offset = y.__sycl_usm_array_interface__["offset"]
96+
expected_offset = i0 * x.strides[0] + (n1 - 1) * x.strides[1]
97+
assert actual_offset == expected_offset
98+
99+
100+
def test_basic_slice7():
101+
q = get_queue_or_skip()
102+
n0, n1, n2 = 5, 3, 2
103+
x = dpt.empty((n0, n1, n2), dtype="?", sycl_queue=q)
104+
y = x[..., ::-1]
105+
assert isinstance(y, dpt.usm_ndarray)
106+
assert y.shape == x.shape
107+
assert y.strides == (
108+
x.strides[0],
109+
x.strides[1],
110+
-x.strides[2],
111+
)
112+
actual_offset = y.__sycl_usm_array_interface__["offset"]
113+
expected_offset = (n2 - 1) * x.strides[2]
114+
assert actual_offset == expected_offset
115+
116+
117+
def test_basic_slice8():
118+
q = get_queue_or_skip()
119+
n0, n1 = 3, 7
120+
x = dpt.empty((n0, n1), dtype="u1", sycl_queue=q)
121+
y = x[..., dpt.newaxis]
122+
assert isinstance(y, dpt.usm_ndarray)
123+
assert y.shape == (n0, n1, 1)
124+
assert y.strides == (n1, 1, 0)
125+
126+
127+
def test_basic_slice9():
128+
q = get_queue_or_skip()
129+
n0, n1 = 3, 7
130+
x = dpt.empty((n0, n1), dtype="u8", sycl_queue=q)
131+
y = x[dpt.newaxis, ...]
132+
assert isinstance(y, dpt.usm_ndarray)
133+
assert y.shape == (1, n0, n1)
134+
assert y.strides == (0, n1, 1)
135+
136+
137+
def test_basic_slice10():
138+
q = get_queue_or_skip()
139+
n0, n1, n2 = 3, 7, 5
140+
x = dpt.empty((n0, n1, n2), dtype="u1", sycl_queue=q)
141+
y = x[dpt.newaxis, ..., :]
142+
assert isinstance(y, dpt.usm_ndarray)
143+
assert y.shape == (1, n0, n1, n2)
144+
assert y.strides == (0, n1 * n2, n2, 1)
145+
146+
147+
def test_advanced_slice1():
148+
q = get_queue_or_skip()
149+
ii = dpt.asarray([1, 2], sycl_queue=q)
150+
x = dpt.arange(10, dtype="i4", sycl_queue=q)
151+
y = x[ii]
152+
assert isinstance(y, dpt.usm_ndarray)
153+
assert y.shape == ii.shape
154+
assert y.strides == (1,)
155+
# FIXME, once usm_ndarray.__equal__ is implemented,
156+
# use of asnumpy should be removed
157+
assert all(
158+
dpt.asnumpy(x[ii[k]]) == dpt.asnumpy(y[k]) for k in range(ii.shape[0])
159+
)

0 commit comments

Comments
 (0)