Skip to content

Commit 66504da

Browse files
dot/matmul fixes in python/cython (#922)
* dot/matmul fixes in python/sython * skip SVD tests with complex128 type
1 parent f7fb92c commit 66504da

File tree

4 files changed

+21
-17
lines changed

4 files changed

+21
-17
lines changed

dpnp/dpnp_algo/dpnp_algo_linearalgebra.pyx

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -244,18 +244,7 @@ cpdef utils.dpnp_descriptor dpnp_matmul(utils.dpnp_descriptor in_array1, utils.d
244244
"""
245245
size_n = 1
246246

247-
if ndim_max > 2:
248-
"""
249-
shape1(5, 3, 2) * shape2(5, 2, 4) -> result(5, 3, 4)
250-
test: pytest tests/test_matmul.py::test_matmul[shape_pair10-types0] -v -s
251-
"""
252-
shape_result = shape1[:-1] + [shape2.back()]
253-
else:
254-
"""
255-
shape1(5,2) * shape2(2,3) -> result(5,3)
256-
test: pytest tests/test_matmul.py::test_matmul[shape_pair0-types0] -v -s
257-
"""
258-
shape_result = shape1[:-1] + shape2[1:]
247+
shape_result = shape1[:-1] + shape2[-1:]
259248

260249
# convert string type names (array.dtype) to C enum DPNPFuncType
261250
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,18 @@ def dot(x1, x2, **kwargs):
9797
dim1 = x1_desc.ndim
9898
dim2 = x2_desc.ndim
9999

100-
if not (dim1 >= 2 and dim2 == 1) and not (dim1 >= 2 and dim2 >= 2) and (x1_desc.dtype == x2_desc.dtype):
100+
# for now we work only with these cases
101+
if (
102+
(dim1 == 1 and dim2 == 1 # vectors
103+
or dim1 == 2 and dim2 == 2 # matrices
104+
# or dim1 == 0 or dim2 == 0 # there is an issue with scalars (dpnp_multiply)
105+
) and (x1_desc.dtype == x2_desc.dtype)):
101106
result_obj = dpnp_dot(x1_desc, x2_desc).get_pyobj()
102-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
103-
104-
return result
107+
if (dim1 == 2 and dim2 == 2):
108+
return result_obj
109+
else:
110+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
111+
return result
105112

106113
return call_origin(numpy.dot, x1, x2, **kwargs)
107114

@@ -246,7 +253,7 @@ def matmul(x1, x2, out=None, **kwargs):
246253
x1_desc = dpnp.get_dpnp_descriptor(x1)
247254
x2_desc = dpnp.get_dpnp_descriptor(x2)
248255
if x1_desc and x2_desc and not kwargs:
249-
if x1_desc.size != x2_desc.size:
256+
if x1_desc.ndim != 2 or x2_desc.ndim != 2:
250257
pass
251258
elif not x1_desc.ndim:
252259
pass

tests/skipped_tests.tbl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ tests/test_linalg.py::test_cond[None-[[1, 0, -1], [0, 1, 0], [1, 0, 1]]]
1313
tests/test_linalg.py::test_cond[None-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
1414
tests/test_linalg.py::test_cond[-numpy.inf-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
1515
tests/test_linalg.py::test_cond[numpy.inf-[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
16+
tests/test_linalg.py::test_svd[(2,2)-complex128]
17+
tests/test_linalg.py::test_svd[(3,4)-complex128]
18+
tests/test_linalg.py::test_svd[(5,3)-complex128]
19+
tests/test_linalg.py::test_svd[(16,16)-complex128]
1620
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp.asarray([(i, i) for i in x], [("a", int), ("b", int)]).view(dpnp.recarray))]
1721
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]]
1822
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)]

tests/skipped_tests_gpu.tbl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ tests/test_linalg.py::test_svd[(5,3)-float32]
3838
tests/test_linalg.py::test_svd[(5,3)-float64]
3939
tests/test_linalg.py::test_svd[(5,3)-int32]
4040
tests/test_linalg.py::test_svd[(5,3)-int64]
41+
tests/test_linalg.py::test_svd[(2,2)-complex128]
42+
tests/test_linalg.py::test_svd[(3,4)-complex128]
43+
tests/test_linalg.py::test_svd[(5,3)-complex128]
44+
tests/test_linalg.py::test_svd[(16,16)-complex128]
4145
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: (dpnp.asarray([(i, i) for i in x], [("a", int), ("b", int)]).view(dpnp.recarray))]
4246
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray([(i, i) for i in x], [("a", object), ("b", dpnp.int32)])]]
4347
tests/test_random.py::TestPermutationsTestShuffle::test_shuffle1[lambda x: dpnp.asarray(x).astype(dpnp.int8)]

0 commit comments

Comments
 (0)