diff --git a/.github/workflows/array-api-skips.txt b/.github/workflows/array-api-skips.txt index bfbf71e906d4..097e6ffef42b 100644 --- a/.github/workflows/array-api-skips.txt +++ b/.github/workflows/array-api-skips.txt @@ -34,9 +34,5 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_asinh array_api_tests/test_signatures.py::test_func_signature[std] array_api_tests/test_signatures.py::test_func_signature[var] -# wrong shape is returned -array_api_tests/test_linalg.py::test_vecdot -array_api_tests/test_linalg.py::test_linalg_vecdot - # arrays have different values array_api_tests/test_linalg.py::test_linalg_tensordot diff --git a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py index 7752da531a25..f90c95370cd4 100644 --- a/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py +++ b/dpnp/dpnp_utils/dpnp_utils_linearalgebra.py @@ -198,13 +198,14 @@ def _define_dim_flags(x, axis): """ Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot. x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one - except for one of them), for instance, if x.shape = (1, 1, 1, 2), - then x_is_1D = True + except for dimension at `axis`), for instance, if x.shape = (1, 1, 1, 2), + and axis=-1, then x_is_1D = True. x_is_2D: `x` is 2D array or inherently 2D (all dimensions are equal to one except for the last two of them), for instance, if x.shape = (1, 1, 3, 2), - then x_is_2D = True + then x_is_2D = True. x_base_is_1D: `x` is 1D considering only its last two dimensions, for instance, - if x.shape = (3, 4, 1, 2), then x_base_is_1D = True + if x.shape = (3, 4, 1, 2), then x_base_is_1D = True. + """ x_shape = x.shape @@ -326,14 +327,11 @@ def _get_result_shape_vecdot(x1, x2, x1_ndim, x2_ndim): if x1_shape[-1] != x2_shape[-1]: _shape_error(x1_shape[-1], x2_shape[-1], "vecdot", err_msg=0) - _, x1_is_1D, _ = _define_dim_flags(x1, axis=-1) - _, x2_is_1D, _ = _define_dim_flags(x2, axis=-1) - if x1_ndim == 1 and x2_ndim == 1: result_shape = () - elif x1_is_1D: + elif x1_ndim == 1: result_shape = x2_shape[:-1] - elif x2_is_1D: + elif x2_ndim == 1: result_shape = x1_shape[:-1] else: # at least 2D if x1_ndim != x2_ndim: diff --git a/dpnp/tests/test_product.py b/dpnp/tests/test_product.py index eab4fe9f53ec..09dfa7d815f2 100644 --- a/dpnp/tests/test_product.py +++ b/dpnp/tests/test_product.py @@ -1000,6 +1000,8 @@ def setup_method(self): ((1, 4, 5), (3, 1, 5)), ((1, 1, 4, 5), (3, 1, 5)), ((1, 4, 5), (1, 3, 1, 5)), + ((2, 1), (1, 1, 1)), + ((1, 1, 3), (3,)), ], ) def test_basic(self, dtype, shape1, shape2):