Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
return view


cdef inline void _check_0d_scalar_conversion(object usm_ary) except *:
"Raise TypeError if array cannot be converted to a Python scalar"
if (usm_ary.ndim != 0):
raise TypeError(
"only 0-dimensional arrays can be converted to Python scalars"
)


cdef int _copy_writable(int lhs_flags, int rhs_flags):
"Copy the WRITABLE flag to lhs_flags from rhs_flags"
return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)
Expand Down Expand Up @@ -1147,6 +1155,7 @@ cdef class usm_ndarray:

def __float__(self):
if self.size == 1:
_check_0d_scalar_conversion(self)
view = _as_zero_dim_ndarray(self)
return view.__float__()

Expand All @@ -1156,6 +1165,7 @@ cdef class usm_ndarray:

def __complex__(self):
if self.size == 1:
_check_0d_scalar_conversion(self)
view = _as_zero_dim_ndarray(self)
return view.__complex__()

Expand All @@ -1165,6 +1175,7 @@ cdef class usm_ndarray:

def __int__(self):
if self.size == 1:
_check_0d_scalar_conversion(self)
view = _as_zero_dim_ndarray(self)
return view.__int__()

Expand Down
75 changes: 51 additions & 24 deletions dpctl/tests/test_usm_ndarray_ctor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import numpy as np
import pytest
from numpy.testing import assert_raises_regex

import dpctl
import dpctl.memory as dpm
Expand Down Expand Up @@ -282,34 +283,60 @@ def test_properties(dt):
V.mT


@pytest.mark.parametrize("func", [bool, float, int, complex])
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
def test_copy_scalar_with_func(func, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert func(X) == func(Y)
class TestCopyScalar:
def test_copy_bool_scalar_with_func(self, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert bool(X) == bool(Y)

@pytest.mark.parametrize("func", [float, int, complex])
def test_copy_numeric_scalar_with_func(self, func, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
# Non-0D numeric arrays must not be convertible to Python scalars
if len(shape) != 0:
assert_raises_regex(TypeError, "only 0-dimensional arrays", func, X)
else:
# 0D arrays are allowed to convert
assert func(X) == func(Y)

def test_copy_bool_scalar_with_method(self, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert getattr(X, "__bool__")() == getattr(Y, "__bool__")()

@pytest.mark.parametrize(
"method", ["__bool__", "__float__", "__int__", "__complex__"]
)
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
def test_copy_scalar_with_method(method, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
assert getattr(X, method)() == getattr(Y, method)()
@pytest.mark.parametrize("method", ["__float__", "__int__", "__complex__"])
def test_copy_numeric_scalar_with_method(self, method, shape, dtype):
try:
X = dpt.usm_ndarray(shape, dtype=dtype)
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
Y = np.arange(1, X.size + 1, dtype=dtype)
X.usm_data.copy_from_host(Y.view("|u1"))
Y.shape = tuple()
if len(shape) != 0:
assert_raises_regex(
TypeError, "only 0-dimensional arrays", getattr(X, method)
)
else:
assert getattr(X, method)() == getattr(Y, method)()


@pytest.mark.parametrize("func", [bool, float, int, complex])
Expand Down
Loading