Skip to content

Commit 64ba877

Browse files
authored
fix cross & unskip tests for cross (#509)
* fix cross & unskip tests for cross
1 parent bc285d8 commit 64ba877

File tree

4 files changed

+31
-19
lines changed

4 files changed

+31
-19
lines changed

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ cpdef dparray dpnp_cross(dparray x1, dparray x2):
144144
(dpnp.float32, dpnp.float32): dpnp.float32,
145145
}
146146

147-
res_type = types_map.get((x1.dtype, x2.dtype), dpnp.float64)
147+
res_type = types_map.get((x1.dtype.type, x2.dtype.type), dpnp.float64)
148148

149149
cdef dparray result = dparray(3, dtype=res_type)
150150

dpnp/dpnp_iface_mathematical.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def copysign(x1, x2, **kwargs):
328328
return call_origin(numpy.copysign, x1, x2, **kwargs)
329329

330330

331-
def cross(x1, x2, **kwargs):
331+
def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
332332
"""
333333
Return the cross product of two (arrays of) vectors.
334334
@@ -354,7 +354,7 @@ def cross(x1, x2, **kwargs):
354354
355355
"""
356356

357-
if not use_origin_backend(x1) and not kwargs:
357+
if not use_origin_backend(x1):
358358
if not isinstance(x1, dparray):
359359
pass
360360
elif not isinstance(x2, dparray):
@@ -363,10 +363,18 @@ def cross(x1, x2, **kwargs):
363363
pass
364364
elif x1.shape != (3,) or x2.shape != (3,):
365365
pass
366+
elif axisa != -1:
367+
pass
368+
elif axisb != -1:
369+
pass
370+
elif axisc != -1:
371+
pass
372+
elif axis is not None:
373+
pass
366374
else:
367375
return dpnp_cross(x1, x2)
368376

369-
return call_origin(numpy.cross, x1, x2, **kwargs)
377+
return call_origin(numpy.cross, x1, x2, axisa, axisb, axisc, axis)
370378

371379

372380
def cumprod(x1, **kwargs):

tests/skipped_tests.tbl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -817,16 +817,6 @@ tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumLarge_param_9_{opt
817817
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_float
818818
tests/third_party/cupy/linalg_tests/test_einsum.py::TestEinSumUnaryOperationWithScalar::test_scalar_int
819819
tests/third_party/cupy/linalg_tests/test_einsum.py::TestListArgEinSumError::test_invalid_sub1
820-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_0_{params=((3,), (3,), -1, -1, -1)}::test_cross
821-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_1_{params=((1, 2), (1, 2), -1, -1, 1)}::test_cross
822-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_2_{params=((1, 3), (1, 3), 1, -1, -1)}::test_cross
823-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_3_{params=((1, 2), (1, 3), -1, -1, 1)}::test_cross
824-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_4_{params=((2, 2), (1, 3), -1, -1, 0)}::test_cross
825-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_5_{params=((3, 3), (1, 2), 0, -1, -1)}::test_cross
826-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_6_{params=((0, 3), (0, 3), -1, -1, -1)}::test_cross
827-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_7_{params=((2, 0, 3), (2, 0, 3), 0, 0, 0)}::test_cross
828-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_8_{params=((2, 4, 5, 3), (2, 4, 5, 3), -1, -1, 0)}::test_cross
829-
tests/third_party/cupy/linalg_tests/test_product.py::TestCrossProduct_param_9_{params=((2, 4, 5, 2), (2, 4, 5, 2), 0, 0, -1)}::test_cross
830820
tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_10_{shape=((4, 2), ()), trans_a=False, trans_b=True}::test_dot
831821
tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_11_{shape=((4, 2), ()), trans_a=False, trans_b=False}::test_dot
832822
tests/third_party/cupy/linalg_tests/test_product.py::TestDotFor0Dim_param_4_{shape=((), (2, 4)), trans_a=True, trans_b=True}::test_dot

tests/test_mathematical.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,19 +127,33 @@ def test_trapz_without_params(self, y_array, dx):
127127

128128
class TestCross:
129129

130+
@pytest.mark.parametrize("axis", [None, 0],
131+
ids=['None', '0'])
132+
@pytest.mark.parametrize("axisc", [-1, 0],
133+
ids=['-1', '0'])
134+
@pytest.mark.parametrize("axisb", [-1, 0],
135+
ids=['-1', '0'])
136+
@pytest.mark.parametrize("axisa", [-1, 0],
137+
ids=['-1', '0'])
130138
@pytest.mark.parametrize("x1", [[1, 2, 3],
131139
[1., 2.5, 6.],
132-
[2, 4, 6]])
140+
[2, 4, 6]],
141+
ids=['[1, 2, 3]',
142+
'[1., 2.5, 6.]',
143+
'[2, 4, 6]'])
133144
@pytest.mark.parametrize("x2", [[4, 5, 6],
134145
[1., 5., 2.],
135-
[6, 4, 3]])
136-
def test_cross_3x3(self, x1, x2):
146+
[6, 4, 3]],
147+
ids=['[4, 5, 6]',
148+
'[1., 5., 2.]',
149+
'[6, 4, 3]'])
150+
def test_cross_3x3(self, x1, x2, axisa, axisb, axisc, axis):
137151
x1_ = numpy.array(x1)
138152
ix1_ = inp.array(x1_)
139153

140154
x2_ = numpy.array(x2)
141155
ix2_ = inp.array(x2_)
142156

143-
result = inp.cross(ix1_, ix2_)
144-
expected = numpy.cross(x1_, x2_)
157+
result = inp.cross(ix1_, ix2_, axisa, axisb, axisc, axis)
158+
expected = numpy.cross(x1_, x2_, axisa, axisb, axisc, axis)
145159
numpy.testing.assert_array_equal(expected, result)

0 commit comments

Comments
 (0)