From 9a568f019bc199609f3fac557998e7d278b5a564 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 24 Jan 2025 01:31:33 -0800 Subject: [PATCH 1/4] Add tests to cover rmatmul operator --- dpnp/tests/test_mathematical.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/dpnp/tests/test_mathematical.py b/dpnp/tests/test_mathematical.py index 500316ab21f3..9c280b45758b 100644 --- a/dpnp/tests/test_mathematical.py +++ b/dpnp/tests/test_mathematical.py @@ -3320,6 +3320,30 @@ def test_matmul_with_offsets(self, sh1, sh2): expected = a[1] @ b[1] assert_dtype_allclose(result, expected) + def test_rmatmul_dpnp_array(self): + a = dpnp.ones(10) + b = dpnp.ones(10) + + class A(dpnp.ndarray): + def __init__(self, x): + self._array_obj = x.get_array() + + def __matmul__(self, other): + return NotImplemented + + a1 = A(a) + + result = a1 @ b + expected = dpnp.matmul(a, b) + assert_dtype_allclose(result, expected) + + def test_rmatmul_numpy_array(self): + a = dpnp.ones(10) + b = numpy.ones(10) + + with pytest.raises(TypeError): + b @ a + class TestMatmulInplace: ALL_DTYPES = get_all_dtypes(no_none=True) From 0d13a6be2c8c693ceb9f27f5dbf7f03c0fc810c0 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 24 Jan 2025 02:15:25 -0800 Subject: [PATCH 2/4] Add a new tests_utill.py file --- dpnp/tests/test_utils.py | 67 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 dpnp/tests/test_utils.py diff --git a/dpnp/tests/test_utils.py b/dpnp/tests/test_utils.py new file mode 100644 index 000000000000..24743dc49717 --- /dev/null +++ b/dpnp/tests/test_utils.py @@ -0,0 +1,67 @@ +import dpctl.tensor as dpt +import numpy +import pytest + +import dpnp + + +class TestIsSupportedArrayOrScalar: + @pytest.mark.parametrize( + "array", + [ + dpnp.array([1, 2, 3]), + dpnp.array(1), + dpt.asarray([1, 2, 3]), + ], + ) + def test_valid_arrays(self, array): + assert dpnp.is_supported_array_or_scalar(array) + + @pytest.mark.parametrize( + "value", + [ + 42, + True, + "1", + ], + ) + def test_valid_scalars(self, value): + assert dpnp.is_supported_array_or_scalar(value) + + @pytest.mark.parametrize( + "array", + [ + [1, 2, 3], + (1, 2, 3), + None, + numpy.array([1, 2, 3]), + ], + ) + def test_invalid_arrays(self, array): + assert not dpnp.is_supported_array_or_scalar(array) + + +class TestSynchronizeArrayData: + @pytest.mark.parametrize( + "array", + [ + dpnp.array([1, 2, 3]), + dpt.asarray([1, 2, 3]), + ], + ) + def test_synchronize_array_data(self, array): + try: + dpnp.synchronize_array_data(array) + except Exception as e: + pytest.fail(f"synchronize_array_data failed: {e}") + + @pytest.mark.parametrize( + "input", + [ + [1, 2, 3], + numpy.array([1, 2, 3]), + ], + ) + def test_unsupported_type(self, input): + with pytest.raises(TypeError): + dpnp.synchronize_array_data(input) From f45cf8fceb3cfec60671afcbc350fa155244dfc1 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Fri, 24 Jan 2025 05:35:02 -0800 Subject: [PATCH 3/4] Apply comments --- dpnp/tests/test_mathematical.py | 24 ------------------------ dpnp/tests/test_ndarray.py | 26 ++++++++++++++++++++++++++ dpnp/tests/test_utils.py | 10 ++++++---- 3 files changed, 32 insertions(+), 28 deletions(-) diff --git a/dpnp/tests/test_mathematical.py b/dpnp/tests/test_mathematical.py index 9c280b45758b..500316ab21f3 100644 --- a/dpnp/tests/test_mathematical.py +++ b/dpnp/tests/test_mathematical.py @@ -3320,30 +3320,6 @@ def test_matmul_with_offsets(self, sh1, sh2): expected = a[1] @ b[1] assert_dtype_allclose(result, expected) - def test_rmatmul_dpnp_array(self): - a = dpnp.ones(10) - b = dpnp.ones(10) - - class A(dpnp.ndarray): - def __init__(self, x): - self._array_obj = x.get_array() - - def __matmul__(self, other): - return NotImplemented - - a1 = A(a) - - result = a1 @ b - expected = dpnp.matmul(a, b) - assert_dtype_allclose(result, expected) - - def test_rmatmul_numpy_array(self): - a = dpnp.ones(10) - b = numpy.ones(10) - - with pytest.raises(TypeError): - b @ a - class TestMatmulInplace: ALL_DTYPES = get_all_dtypes(no_none=True) diff --git a/dpnp/tests/test_ndarray.py b/dpnp/tests/test_ndarray.py index 0a394c291585..e66da4fc0129 100644 --- a/dpnp/tests/test_ndarray.py +++ b/dpnp/tests/test_ndarray.py @@ -460,3 +460,29 @@ def test_clip(): expected = numpy.clip(numpy_array, 3, 7) assert_array_equal(expected, result) + + +def test_rmatmul_dpnp_array(): + a = dpnp.ones(10) + b = dpnp.ones(10) + + class Dummy(dpnp.ndarray): + def __init__(self, x): + self._array_obj = x.get_array() + + def __matmul__(self, other): + return NotImplemented + + d = Dummy(a) + + result = d @ b + expected = dpnp.matmul(a, b) + assert_allclose(result, expected) + + +def test_rmatmul_numpy_array(): + a = dpnp.ones(10) + b = numpy.ones(10) + + with pytest.raises(TypeError): + b @ a diff --git a/dpnp/tests/test_utils.py b/dpnp/tests/test_utils.py index 24743dc49717..89e97b75d5e5 100644 --- a/dpnp/tests/test_utils.py +++ b/dpnp/tests/test_utils.py @@ -1,3 +1,4 @@ +import dpctl import dpctl.tensor as dpt import numpy import pytest @@ -15,7 +16,7 @@ class TestIsSupportedArrayOrScalar: ], ) def test_valid_arrays(self, array): - assert dpnp.is_supported_array_or_scalar(array) + assert dpnp.is_supported_array_or_scalar(array) is True @pytest.mark.parametrize( "value", @@ -26,7 +27,7 @@ def test_valid_arrays(self, array): ], ) def test_valid_scalars(self, value): - assert dpnp.is_supported_array_or_scalar(value) + assert dpnp.is_supported_array_or_scalar(value) is True @pytest.mark.parametrize( "array", @@ -38,7 +39,7 @@ def test_valid_scalars(self, value): ], ) def test_invalid_arrays(self, array): - assert not dpnp.is_supported_array_or_scalar(array) + assert not dpnp.is_supported_array_or_scalar(array) is True class TestSynchronizeArrayData: @@ -50,8 +51,9 @@ class TestSynchronizeArrayData: ], ) def test_synchronize_array_data(self, array): + a_copy = dpnp.copy(array, sycl_queue=array.sycl_queue) try: - dpnp.synchronize_array_data(array) + dpnp.synchronize_array_data(a_copy) except Exception as e: pytest.fail(f"synchronize_array_data failed: {e}") From cb02b13bfd84163ed3d7dfd60021e0e16114a1c6 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Tue, 28 Jan 2025 03:47:03 -0800 Subject: [PATCH 4/4] Apply remarks --- dpnp/tests/test_ndarray.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dpnp/tests/test_ndarray.py b/dpnp/tests/test_ndarray.py index e66da4fc0129..dd790c8e56a1 100644 --- a/dpnp/tests/test_ndarray.py +++ b/dpnp/tests/test_ndarray.py @@ -476,8 +476,8 @@ def __matmul__(self, other): d = Dummy(a) result = d @ b - expected = dpnp.matmul(a, b) - assert_allclose(result, expected) + expected = a @ b + assert (result == expected).all() def test_rmatmul_numpy_array():