Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,9 @@ def __and__(self, other):
# '__array_prepare__',
# '__array_priority__',
# '__array_struct__',
# '__array_ufunc__',

__array_ufunc__ = None

# '__array_wrap__',

def __array_namespace__(self, /, *, api_version=None):
Expand Down
14 changes: 14 additions & 0 deletions dpnp/tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,20 @@ def get_all_dtypes(
return dtypes


def get_array(xp, a):
"""
Cast input array `a` to a type supported by `xp` initerface.

Implicit conversion of either DPNP or DPCTL array to a NumPy array is not
allowed. Input array has to be explicitly casted with `asnumpy` function.

"""

if xp is numpy and dpnp.is_supported_array_type(a):
return dpnp.asnumpy(a)
return a


def generate_random_numpy_array(
shape,
dtype=None,
Expand Down
3 changes: 2 additions & 1 deletion dpnp/tests/test_arraycreation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .helper import (
assert_dtype_allclose,
get_all_dtypes,
get_array,
)
from .third_party.cupy import testing

Expand Down Expand Up @@ -768,7 +769,7 @@ def test_space_numpy_dtype(func, start_dtype, stop_dtype):
],
)
def test_linspace_arrays(start, stop):
func = lambda xp: xp.linspace(start, stop, 10)
func = lambda xp: xp.linspace(get_array(xp, start), get_array(xp, stop), 10)
assert func(numpy).shape == func(dpnp).shape


Expand Down
15 changes: 7 additions & 8 deletions dpnp/tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,7 @@ def test_matrix_rank(self, data, dtype):

np_rank = numpy.linalg.matrix_rank(a)
dp_rank = dpnp.linalg.matrix_rank(a_dp)
assert np_rank == dp_rank
assert dp_rank.asnumpy() == np_rank

@pytest.mark.parametrize("dtype", get_all_dtypes())
@pytest.mark.parametrize(
Expand All @@ -1953,7 +1953,7 @@ def test_matrix_rank_hermitian(self, data, dtype):

np_rank = numpy.linalg.matrix_rank(a, hermitian=True)
dp_rank = dpnp.linalg.matrix_rank(a_dp, hermitian=True)
assert np_rank == dp_rank
assert dp_rank.asnumpy() == np_rank

@pytest.mark.parametrize(
"high_tol, low_tol",
Expand Down Expand Up @@ -1986,15 +1986,15 @@ def test_matrix_rank_tolerance(self, high_tol, low_tol):
dp_rank_high_tol = dpnp.linalg.matrix_rank(
a_dp, hermitian=True, tol=dp_high_tol
)
assert np_rank_high_tol == dp_rank_high_tol
assert dp_rank_high_tol.asnumpy() == np_rank_high_tol

np_rank_low_tol = numpy.linalg.matrix_rank(
a, hermitian=True, tol=low_tol
)
dp_rank_low_tol = dpnp.linalg.matrix_rank(
a_dp, hermitian=True, tol=dp_low_tol
)
assert np_rank_low_tol == dp_rank_low_tol
assert dp_rank_low_tol.asnumpy() == np_rank_low_tol

# rtol kwarg was added in numpy 2.0
@testing.with_requires("numpy>=2.0")
Expand Down Expand Up @@ -2807,15 +2807,14 @@ def check_decomposition(
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
dpnp_diag_s[..., i, i] = dp_s[..., i]
reconstructed = dpnp.dot(dp_u, dpnp.dot(dpnp_diag_s, dp_vt))
# TODO: use assert dpnp.allclose() inside check_decomposition()
# when it will support complex dtypes
assert_allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)

assert dpnp.allclose(dp_a, reconstructed, rtol=tol, atol=1e-4)

assert_allclose(dp_s, np_s, rtol=tol, atol=1e-03)

if compute_vt:
for i in range(min(dp_a.shape[-2], dp_a.shape[-1])):
if np_u[..., 0, i] * dp_u[..., 0, i] < 0:
if np_u[..., 0, i] * dpnp.asnumpy(dp_u[..., 0, i]) < 0:
np_u[..., :, i] = -np_u[..., :, i]
np_vt[..., i, :] = -np_vt[..., i, :]
for i in range(numpy.count_nonzero(np_s > tol)):
Expand Down
6 changes: 5 additions & 1 deletion dpnp/tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .helper import (
assert_dtype_allclose,
get_all_dtypes,
get_array,
get_complex_dtypes,
get_float_complex_dtypes,
get_float_dtypes,
Expand Down Expand Up @@ -1232,7 +1233,10 @@ def test_axes(self):
def test_axes_type(self, axes):
a = numpy.ones((50, 40, 3))
ia = dpnp.array(a)
assert_equal(dpnp.rot90(ia, axes=axes), numpy.rot90(a, axes=axes))
assert_equal(
dpnp.rot90(ia, axes=axes),
numpy.rot90(a, axes=get_array(numpy, axes)),
)

def test_rotation_axes(self):
a = numpy.arange(8).reshape((2, 2, 2))
Expand Down
21 changes: 21 additions & 0 deletions dpnp/tests/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,27 @@ def test_wrong_api_version(self, api_version):
)


class TestArrayUfunc:
def test_add(self):
a = numpy.ones(10)
b = dpnp.ones(10)
msg = "An array must be any of supported type"

with assert_raises_regex(TypeError, msg):
a + b

with assert_raises_regex(TypeError, msg):
b + a

def test_add_inplace(self):
a = numpy.ones(10)
b = dpnp.ones(10)
with assert_raises_regex(
TypeError, "operand 'dpnp_array' does not support ufuncs"
):
a += b


class TestItem:
@pytest.mark.parametrize("args", [2, 7, (1, 2), (2, 0)])
def test_basic(self, args):
Expand Down
Loading