Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/array-api-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,5 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_asinh
array_api_tests/test_signatures.py::test_func_signature[std]
array_api_tests/test_signatures.py::test_func_signature[var]

# wrong shape is returned
array_api_tests/test_linalg.py::test_vecdot
array_api_tests/test_linalg.py::test_linalg_vecdot

# arrays have different values
array_api_tests/test_linalg.py::test_linalg_tensordot
13 changes: 7 additions & 6 deletions dpnp/dpnp_utils/dpnp_utils_linearalgebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,14 @@ def _define_dim_flags(x, axis):
"""
Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot.
x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one
except for one of them), for instance, if x.shape = (1, 1, 1, 2),
then x_is_1D = True
except for dimension at `axis`), for instance, if x.shape = (1, 1, 1, 2),
and axis=-1, then x_is_1D = True.
x_is_2D: `x` is 2D array or inherently 2D (all dimensions are equal to one
except for the last two of them), for instance, if x.shape = (1, 1, 3, 2),
then x_is_2D = True
then x_is_2D = True.
x_base_is_1D: `x` is 1D considering only its last two dimensions, for instance,
if x.shape = (3, 4, 1, 2), then x_base_is_1D = True
if x.shape = (3, 4, 1, 2), then x_base_is_1D = True.

"""

x_shape = x.shape
Expand Down Expand Up @@ -331,9 +332,9 @@ def _get_result_shape_vecdot(x1, x2, x1_ndim, x2_ndim):

if x1_ndim == 1 and x2_ndim == 1:
result_shape = ()
elif x1_is_1D:
elif x1_ndim == 1:
result_shape = x2_shape[:-1]
elif x2_is_1D:
elif x2_ndim == 1:
result_shape = x1_shape[:-1]
else: # at least 2D
if x1_ndim != x2_ndim:
Expand Down
2 changes: 2 additions & 0 deletions dpnp/tests/test_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1000,6 +1000,8 @@ def setup_method(self):
((1, 4, 5), (3, 1, 5)),
((1, 1, 4, 5), (3, 1, 5)),
((1, 4, 5), (1, 3, 1, 5)),
((2, 1), (1, 1, 1)),
((1, 1, 3), (3,)),
],
)
def test_basic(self, dtype, shape1, shape2):
Expand Down
Loading