From 5602db56e9490279d59caf38f2f74eed94172923 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 21 Oct 2024 13:37:14 -0500 Subject: [PATCH] following NEP-50 for dpnp.einsum --- dpnp/dpnp_iface_statistics.py | 2 +- dpnp/dpnp_utils/dpnp_utils_einsum.py | 16 +++++++--------- tests/test_linalg.py | 12 +++++------- .../third_party/cupy/linalg_tests/test_einsum.py | 9 ++++----- 4 files changed, 17 insertions(+), 22 deletions(-) diff --git a/dpnp/dpnp_iface_statistics.py b/dpnp/dpnp_iface_statistics.py index bc7d323ee16f..03aa3a6516d6 100644 --- a/dpnp/dpnp_iface_statistics.py +++ b/dpnp/dpnp_iface_statistics.py @@ -370,7 +370,7 @@ def correlate(x1, x2, mode="valid"): ----------- Input arrays are supported as :obj:`dpnp.ndarray`. Size and shape of input arrays are supported to be equal. - Parameter `mode` is supported only with default value ``"valid``. + Parameter `mode` is supported only with default value ``"valid"``. Otherwise the function will be executed sequentially on CPU. Input array data types are limited by supported DPNP :ref:`Data types`. diff --git a/dpnp/dpnp_utils/dpnp_utils_einsum.py b/dpnp/dpnp_utils/dpnp_utils_einsum.py index af87419b062a..700a02b0a3e7 100644 --- a/dpnp/dpnp_utils/dpnp_utils_einsum.py +++ b/dpnp/dpnp_utils/dpnp_utils_einsum.py @@ -33,9 +33,8 @@ from dpctl.utils import ExecutionPlacementError import dpnp -from dpnp.dpnp_utils import get_usm_allocations - -from ..dpnp_array import dpnp_array +from dpnp.dpnp_array import dpnp_array +from dpnp.dpnp_utils import get_usm_allocations, map_dtype_to_device _einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" @@ -1027,17 +1026,16 @@ def dpnp_einsum( "Input and output allocation queues are not compatible" ) - result_dtype = dpnp.result_type(*arrays) if dtype is None else dtype for id, a in enumerate(operands): if dpnp.isscalar(a): + scalar_dtype = map_dtype_to_device(type(a), exec_q.sycl_device) operands[id] = dpnp.array( - a, dtype=result_dtype, usm_type=res_usm_type, sycl_queue=exec_q + a, dtype=scalar_dtype, usm_type=res_usm_type, sycl_queue=exec_q ) + arrays.append(operands[id]) result_dtype = dpnp.result_type(*arrays) if dtype is None else dtype - if order in ["a", "A"]: - order = ( - "F" if not any(arr.flags.c_contiguous for arr in arrays) else "C" - ) + if order in "aA": + order = "F" if all(arr.flags.fnc for arr in arrays) else "C" input_subscripts = [ _parse_ellipsis_subscript(sub, idx, ndim=arr.ndim) diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 4018c5cdc08f..d80d62703f15 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -1139,14 +1139,12 @@ def check_einsum_sums(self, dtype, do_opt=False): result = inp.einsum(*args, dtype="?", casting="unsafe", optimize=do_opt) assert_dtype_allclose(result, expected) - # with an scalar, NumPy < 2.0.0 uses the other input arrays to - # determine the output type while for NumPy > 2.0.0 the scalar - # with default machine dtype is used to determine the output - # data type + # NumPy >= 2.0 follows NEP-50 to determine the output dtype when one of + # the inputs is a scalar while NumPy < 2.0 does not if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0": - check_type = True - else: check_type = False + else: + check_type = True a = numpy.arange(9, dtype=dtype) a_dp = inp.array(a) expected = numpy.einsum(",i->", 3, a) @@ -1712,7 +1710,7 @@ def test_broadcasting_dot_cases(self): def test_output_order(self): # Ensure output order is respected for optimize cases, the below - # conraction should yield a reshaped tensor view + # contraction should yield a reshaped tensor view a = inp.ones((2, 3, 5), order="F") b = inp.ones((4, 3), order="F") diff --git a/tests/third_party/cupy/linalg_tests/test_einsum.py b/tests/third_party/cupy/linalg_tests/test_einsum.py index fa2e28b6aa79..20107b7a3b30 100644 --- a/tests/third_party/cupy/linalg_tests/test_einsum.py +++ b/tests/third_party/cupy/linalg_tests/test_einsum.py @@ -475,13 +475,12 @@ def test_einsum_binary(self, xp, dtype_a, dtype_b): class TestEinSumBinaryOperationWithScalar: - # with an scalar, NumPy < 2.0.0 uses the other input arrays to determine - # the output type while for NumPy > 2.0.0 the scalar with default machine - # dtype is used to determine the output type + # NumPy >= 2.0 follows NEP-50 to determine the output dtype when one of + # the inputs is a scalar while NumPy < 2.0 does not if numpy.lib.NumpyVersion(numpy.__version__) < "2.0.0": - type_check = has_support_aspect64() - else: type_check = False + else: + type_check = has_support_aspect64() @testing.for_all_dtypes() @testing.numpy_cupy_allclose(contiguous_check=False, type_check=type_check)