Skip to content

Commit 0797836

Browse files
authored
Merge pull request #2223 from IntelPython/disallow_conv_to_scalar_ndim
Disallow scalar conversation for non-0D arrays
2 parents a12c1c7 + 5f28594 commit 0797836

File tree

7 files changed

+70
-44
lines changed

7 files changed

+70
-44
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1010

1111
### Changed
1212

13+
* Disallowed scalar conversion for non-0D `tensor.usm_ndarray` per Python Array API specification [gh-2223](https://github.com/IntelPython/dpctl/pull/2223)
14+
1315
### Fixed
1416

1517
### Maintenance

dpctl/tensor/_usmarray.pyx

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
124124
return view
125125

126126

127+
cdef inline void _check_0d_scalar_conversion(object usm_ary) except *:
128+
"Raise TypeError if array cannot be converted to a Python scalar"
129+
if (usm_ary.ndim != 0):
130+
raise TypeError(
131+
"only 0-dimensional arrays can be converted to Python scalars"
132+
)
133+
134+
127135
cdef int _copy_writable(int lhs_flags, int rhs_flags):
128136
"Copy the WRITABLE flag to lhs_flags from rhs_flags"
129137
return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)
@@ -1132,6 +1140,7 @@ cdef class usm_ndarray:
11321140

11331141
def __bool__(self):
11341142
if self.size == 1:
1143+
_check_0d_scalar_conversion(self)
11351144
view = _as_zero_dim_ndarray(self)
11361145
return view.__bool__()
11371146

@@ -1147,6 +1156,7 @@ cdef class usm_ndarray:
11471156

11481157
def __float__(self):
11491158
if self.size == 1:
1159+
_check_0d_scalar_conversion(self)
11501160
view = _as_zero_dim_ndarray(self)
11511161
return view.__float__()
11521162

@@ -1156,6 +1166,7 @@ cdef class usm_ndarray:
11561166

11571167
def __complex__(self):
11581168
if self.size == 1:
1169+
_check_0d_scalar_conversion(self)
11591170
view = _as_zero_dim_ndarray(self)
11601171
return view.__complex__()
11611172

@@ -1165,6 +1176,7 @@ cdef class usm_ndarray:
11651176

11661177
def __int__(self):
11671178
if self.size == 1:
1179+
_check_0d_scalar_conversion(self)
11681180
view = _as_zero_dim_ndarray(self)
11691181
return view.__int__()
11701182

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import numpy as np
2222
import pytest
23+
from numpy.testing import assert_raises_regex
2324

2425
import dpctl
2526
import dpctl.memory as dpm
@@ -282,34 +283,42 @@ def test_properties(dt):
282283
V.mT
283284

284285

285-
@pytest.mark.parametrize("func", [bool, float, int, complex])
286-
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
287-
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
288-
def test_copy_scalar_with_func(func, shape, dtype):
289-
try:
290-
X = dpt.usm_ndarray(shape, dtype=dtype)
291-
except dpctl.SyclDeviceCreationError:
292-
pytest.skip("No SYCL devices available")
293-
Y = np.arange(1, X.size + 1, dtype=dtype)
294-
X.usm_data.copy_from_host(Y.view("|u1"))
295-
Y.shape = tuple()
296-
assert func(X) == func(Y)
297-
298-
299-
@pytest.mark.parametrize(
300-
"method", ["__bool__", "__float__", "__int__", "__complex__"]
301-
)
302286
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
303287
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
304-
def test_copy_scalar_with_method(method, shape, dtype):
305-
try:
306-
X = dpt.usm_ndarray(shape, dtype=dtype)
307-
except dpctl.SyclDeviceCreationError:
308-
pytest.skip("No SYCL devices available")
309-
Y = np.arange(1, X.size + 1, dtype=dtype)
310-
X.usm_data.copy_from_host(Y.view("|u1"))
311-
Y.shape = tuple()
312-
assert getattr(X, method)() == getattr(Y, method)()
288+
class TestCopyScalar:
289+
@pytest.mark.parametrize("func", [bool, float, int, complex])
290+
def test_copy_scalar_with_func(self, func, shape, dtype):
291+
try:
292+
X = dpt.usm_ndarray(shape, dtype=dtype)
293+
except dpctl.SyclDeviceCreationError:
294+
pytest.skip("No SYCL devices available")
295+
Y = np.arange(1, X.size + 1, dtype=dtype)
296+
X.usm_data.copy_from_host(Y.view("|u1"))
297+
Y = Y.reshape(())
298+
# Non-0D numeric arrays must not be convertible to Python scalars
299+
if len(shape) != 0:
300+
assert_raises_regex(TypeError, "only 0-dimensional arrays", func, X)
301+
else:
302+
# 0D arrays are allowed to convert
303+
assert func(X) == func(Y)
304+
305+
@pytest.mark.parametrize(
306+
"method", ["__bool__", "__float__", "__int__", "__complex__"]
307+
)
308+
def test_copy_scalar_with_method(self, method, shape, dtype):
309+
try:
310+
X = dpt.usm_ndarray(shape, dtype=dtype)
311+
except dpctl.SyclDeviceCreationError:
312+
pytest.skip("No SYCL devices available")
313+
Y = np.arange(1, X.size + 1, dtype=dtype)
314+
X.usm_data.copy_from_host(Y.view("|u1"))
315+
Y = Y.reshape(())
316+
if len(shape) != 0:
317+
assert_raises_regex(
318+
TypeError, "only 0-dimensional arrays", getattr(X, method)
319+
)
320+
else:
321+
assert getattr(X, method)() == getattr(Y, method)()
313322

314323

315324
@pytest.mark.parametrize("func", [bool, float, int, complex])

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,11 +1430,12 @@ def test_nonzero_f_contig():
14301430
mask = dpt.zeros((5, 5), dtype="?", order="F")
14311431
mask[2, 3] = True
14321432

1433-
expected_res = (2, 3)
1434-
res = dpt.nonzero(mask)
1433+
expected_res = np.nonzero(dpt.asnumpy(mask))
1434+
result = dpt.nonzero(mask)
14351435

1436-
assert expected_res == res
1437-
assert mask[res]
1436+
for exp, res in zip(expected_res, result):
1437+
assert_array_equal(dpt.asnumpy(res), exp)
1438+
assert dpt.asnumpy(mask[result]).all()
14381439

14391440

14401441
def test_nonzero_compacting():
@@ -1448,11 +1449,12 @@ def test_nonzero_compacting():
14481449
mask[3, 2, 1] = True
14491450
mask_view = mask[..., :3]
14501451

1451-
expected_res = (3, 2, 1)
1452-
res = dpt.nonzero(mask_view)
1452+
expected_res = np.nonzero(dpt.asnumpy(mask_view))
1453+
result = dpt.nonzero(mask_view)
14531454

1454-
assert expected_res == res
1455-
assert mask_view[res]
1455+
for exp, res in zip(expected_res, result):
1456+
assert_array_equal(dpt.asnumpy(res), exp)
1457+
assert dpt.asnumpy(mask_view[result]).all()
14561458

14571459

14581460
def test_assign_scalar():

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,11 +1438,11 @@ def test_tile_size_1():
14381438
# test for gh-1627 behavior
14391439
res = dpt.tile(x1, reps)
14401440
assert x1.shape == res.shape
1441-
assert x1 == res
1441+
assert_array_equal(dpt.asnumpy(x1), dpt.asnumpy(res))
14421442

14431443
res = dpt.tile(x2, reps)
14441444
assert x2.shape == res.shape
1445-
assert x2 == res
1445+
assert_array_equal(dpt.asnumpy(x2), dpt.asnumpy(res))
14461446

14471447

14481448
def test_tile_prepends_axes():

dpctl/tests/test_usm_ndarray_operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def test_mat_ops(namespace):
129129
@pytest.mark.parametrize("namespace", [dpt, Dummy()])
130130
def test_comp_ops(namespace):
131131
try:
132-
X = dpt.ones(1, dtype="u8")
132+
X = dpt.asarray(1, dtype="u8")
133133
except dpctl.SyclDeviceCreationError:
134134
pytest.skip("No SYCL devices available")
135135
X._set_namespace(namespace)

dpctl/tests/test_usm_ndarray_sorting.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import numpy as np
2020
import pytest
21+
from numpy.testing import assert_array_equal
2122

2223
import dpctl.tensor as dpt
2324
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
@@ -345,19 +346,19 @@ def test_radix_sort_size_1_axis():
345346

346347
x1 = dpt.ones((), dtype="i1")
347348
r1 = dpt.sort(x1, kind="radixsort")
348-
assert r1 == x1
349+
assert_array_equal(dpt.asnumpy(r1), dpt.asnumpy(x1))
349350

350351
x2 = dpt.ones([1], dtype="i1")
351352
r2 = dpt.sort(x2, kind="radixsort")
352-
assert r2 == x2
353+
assert_array_equal(dpt.asnumpy(r2), dpt.asnumpy(x2))
353354

354355
x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1))
355356
r3 = dpt.sort(x3, kind="radixsort")
356-
assert dpt.all(r3 == x3)
357+
assert dpt.asnumpy(r3 == x3).all()
357358

358359
x4 = dpt.reshape(dpt.arange(10, dtype="i1"), (1, 10))
359360
r4 = dpt.sort(x4, axis=0, kind="radixsort")
360-
assert dpt.all(r4 == x4)
361+
assert dpt.asnumpy(r4 == x4).all()
361362

362363

363364
def test_radix_argsort_size_1_axis():
@@ -369,12 +370,12 @@ def test_radix_argsort_size_1_axis():
369370

370371
x2 = dpt.ones([1], dtype="i1")
371372
r2 = dpt.argsort(x2, kind="radixsort")
372-
assert r2 == 0
373+
assert dpt.asnumpy(r2 == 0).all()
373374

374375
x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1))
375376
r3 = dpt.argsort(x3, kind="radixsort")
376-
assert dpt.all(r3 == 0)
377+
assert dpt.asnumpy(r3 == 0).all()
377378

378379
x4 = dpt.reshape(dpt.arange(10, dtype="i1"), (1, 10))
379380
r4 = dpt.argsort(x4, axis=0, kind="radixsort")
380-
assert dpt.all(r4 == 0)
381+
assert dpt.asnumpy(r4 == 0).all()

0 commit comments

Comments
 (0)