diff --git a/dpnp/tests/test_ndarray.py b/dpnp/tests/test_ndarray.py index 0a394c291585..dd790c8e56a1 100644 --- a/dpnp/tests/test_ndarray.py +++ b/dpnp/tests/test_ndarray.py @@ -460,3 +460,29 @@ def test_clip(): expected = numpy.clip(numpy_array, 3, 7) assert_array_equal(expected, result) + + +def test_rmatmul_dpnp_array(): + a = dpnp.ones(10) + b = dpnp.ones(10) + + class Dummy(dpnp.ndarray): + def __init__(self, x): + self._array_obj = x.get_array() + + def __matmul__(self, other): + return NotImplemented + + d = Dummy(a) + + result = d @ b + expected = a @ b + assert (result == expected).all() + + +def test_rmatmul_numpy_array(): + a = dpnp.ones(10) + b = numpy.ones(10) + + with pytest.raises(TypeError): + b @ a diff --git a/dpnp/tests/test_utils.py b/dpnp/tests/test_utils.py new file mode 100644 index 000000000000..89e97b75d5e5 --- /dev/null +++ b/dpnp/tests/test_utils.py @@ -0,0 +1,69 @@ +import dpctl +import dpctl.tensor as dpt +import numpy +import pytest + +import dpnp + + +class TestIsSupportedArrayOrScalar: + @pytest.mark.parametrize( + "array", + [ + dpnp.array([1, 2, 3]), + dpnp.array(1), + dpt.asarray([1, 2, 3]), + ], + ) + def test_valid_arrays(self, array): + assert dpnp.is_supported_array_or_scalar(array) is True + + @pytest.mark.parametrize( + "value", + [ + 42, + True, + "1", + ], + ) + def test_valid_scalars(self, value): + assert dpnp.is_supported_array_or_scalar(value) is True + + @pytest.mark.parametrize( + "array", + [ + [1, 2, 3], + (1, 2, 3), + None, + numpy.array([1, 2, 3]), + ], + ) + def test_invalid_arrays(self, array): + assert not dpnp.is_supported_array_or_scalar(array) is True + + +class TestSynchronizeArrayData: + @pytest.mark.parametrize( + "array", + [ + dpnp.array([1, 2, 3]), + dpt.asarray([1, 2, 3]), + ], + ) + def test_synchronize_array_data(self, array): + a_copy = dpnp.copy(array, sycl_queue=array.sycl_queue) + try: + dpnp.synchronize_array_data(a_copy) + except Exception as e: + pytest.fail(f"synchronize_array_data failed: {e}") + + @pytest.mark.parametrize( + "input", + [ + [1, 2, 3], + numpy.array([1, 2, 3]), + ], + ) + def test_unsupported_type(self, input): + with pytest.raises(TypeError): + dpnp.synchronize_array_data(input)