Skip to content

Commit f59daaf

Browse files
committed
Update linalg_tests/test_product.py
1 parent 6cc858c commit f59daaf

File tree

1 file changed

+85
-7
lines changed

1 file changed

+85
-7
lines changed

dpnp/tests/third_party/cupy/linalg_tests/test_product.py

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import sys
12
import unittest
3+
import warnings
24

35
import numpy
46
import pytest
@@ -37,6 +39,7 @@
3739
)
3840
)
3941
class TestDot(unittest.TestCase):
42+
4043
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
4144
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
4245
def test_dot(self, xp, dtype_a, dtype_b):
@@ -87,32 +90,97 @@ def test_dot_with_out(self, xp, dtype_a, dtype_b, dtype_c):
8790
# Test for 0 dimension
8891
((3,), (3,), -1, -1, -1),
8992
# Test for basic cases
90-
((1, 2), (1, 2), -1, -1, 1),
9193
((1, 3), (1, 3), 1, -1, -1),
94+
# Test for higher dimensions
95+
((2, 4, 5, 3), (2, 4, 5, 3), -1, -1, 0),
96+
],
97+
}
98+
)
99+
)
100+
class TestCrossProduct(unittest.TestCase):
101+
102+
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
103+
@testing.numpy_cupy_allclose()
104+
def test_cross(self, xp, dtype_a, dtype_b):
105+
if dtype_a == dtype_b == numpy.bool_:
106+
# cross does not support bool-bool inputs.
107+
return xp.array(True)
108+
shape_a, shape_b, axisa, axisb, axisc = self.params
109+
a = testing.shaped_arange(shape_a, xp, dtype_a)
110+
b = testing.shaped_arange(shape_b, xp, dtype_b)
111+
return xp.cross(a, b, axisa, axisb, axisc)
112+
113+
114+
# XXX: cross with 2D vectors is deprecated in NumPy 2.0, also CuPy 1.14
115+
@testing.parameterize(
116+
*testing.product(
117+
{
118+
"params": [
119+
# Test for basic cases
120+
((1, 2), (1, 2), -1, -1, 1),
92121
((1, 2), (1, 3), -1, -1, 1),
93122
((2, 2), (1, 3), -1, -1, 0),
94123
((3, 3), (1, 2), 0, -1, -1),
95124
((0, 3), (0, 3), -1, -1, -1),
96125
# Test for higher dimensions
97126
((2, 0, 3), (2, 0, 3), 0, 0, 0),
98-
((2, 4, 5, 3), (2, 4, 5, 3), -1, -1, 0),
99127
((2, 4, 5, 2), (2, 4, 5, 2), 0, 0, -1),
100128
],
101129
}
102130
)
103131
)
104-
class TestCrossProduct(unittest.TestCase):
105-
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
132+
class TestCrossProductDeprecated(unittest.TestCase):
106133
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
107-
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
134+
@testing.numpy_cupy_allclose()
108135
def test_cross(self, xp, dtype_a, dtype_b):
109136
if dtype_a == dtype_b == numpy.bool_:
110137
# cross does not support bool-bool inputs.
111138
return xp.array(True)
112139
shape_a, shape_b, axisa, axisb, axisc = self.params
113140
a = testing.shaped_arange(shape_a, xp, dtype_a)
114141
b = testing.shaped_arange(shape_b, xp, dtype_b)
115-
return xp.cross(a, b, axisa, axisb, axisc)
142+
143+
with warnings.catch_warnings():
144+
warnings.simplefilter("ignore", DeprecationWarning)
145+
res = xp.cross(a, b, axisa, axisb, axisc)
146+
return res
147+
148+
149+
@testing.parameterize(
150+
*testing.product(
151+
{
152+
"params": [
153+
# Test for 0 dimension
154+
(
155+
(3,),
156+
(3,),
157+
-1,
158+
),
159+
# Test for basic cases
160+
(
161+
(1, 3),
162+
(1, 3),
163+
1,
164+
),
165+
# Test for higher dimensions
166+
((2, 4, 5, 3), (2, 4, 5, 3), -1),
167+
],
168+
}
169+
)
170+
)
171+
class TestLinalgCrossProduct(unittest.TestCase):
172+
173+
@testing.with_requires("numpy>=2.0")
174+
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
175+
@testing.numpy_cupy_allclose()
176+
def test_cross(self, xp, dtype_a, dtype_b):
177+
if dtype_a == dtype_b == numpy.bool_:
178+
# cross does not support bool-bool inputs.
179+
return xp.array(True)
180+
shape_a, shape_b, axis = self.params
181+
a = testing.shaped_arange(shape_a, xp, dtype_a)
182+
b = testing.shaped_arange(shape_b, xp, dtype_b)
183+
return xp.linalg.cross(a, b, axis=axis)
116184

117185

118186
@testing.parameterize(
@@ -129,6 +197,7 @@ def test_cross(self, xp, dtype_a, dtype_b):
129197
)
130198
)
131199
class TestDotFor0Dim(unittest.TestCase):
200+
132201
@testing.for_all_dtypes_combination(["dtype_a", "dtype_b"])
133202
@testing.numpy_cupy_allclose(
134203
type_check=has_support_aspect64(), contiguous_check=False
@@ -147,6 +216,7 @@ def test_dot(self, xp, dtype_a, dtype_b):
147216

148217

149218
class TestProduct:
219+
150220
@testing.for_all_dtypes()
151221
@testing.numpy_cupy_allclose()
152222
def test_dot_vec1(self, xp, dtype):
@@ -403,7 +473,9 @@ def test_zerodim_kron(self, xp, dtype):
403473
)
404474
@testing.numpy_cupy_allclose(type_check=has_support_aspect64())
405475
def test_kron_accepts_numbers_as_arguments(self, a, b, xp):
406-
args = [xp.array(arg) if type(arg) == list else arg for arg in [a, b]]
476+
args = [
477+
xp.array(arg) if isinstance(arg, list) else arg for arg in [a, b]
478+
]
407479
return xp.kron(*args)
408480

409481

@@ -422,6 +494,7 @@ def test_kron_accepts_numbers_as_arguments(self, a, b, xp):
422494
)
423495
)
424496
class TestProductZeroLength(unittest.TestCase):
497+
425498
@testing.for_all_dtypes()
426499
@testing.numpy_cupy_allclose()
427500
def test_tensordot_zero_length(self, xp, dtype):
@@ -488,9 +561,13 @@ def test_matrix_power_large(self, xp, dtype):
488561
a = xp.eye(23, k=17, dtype=dtype) + xp.eye(23, k=-6, dtype=dtype)
489562
return xp.linalg.matrix_power(a, 123456789123456789)
490563

564+
@pytest.mark.skipif(
565+
sys.platform == "win32", reason="python int overflows C long"
566+
)
491567
@testing.for_float_dtypes(no_float16=True)
492568
@testing.numpy_cupy_allclose()
493569
def test_matrix_power_invlarge(self, xp, dtype):
570+
# TODO (ev-br): np 2.0: check if it's fixed in numpy 2 (broken on 1.26)
494571
a = xp.eye(23, k=17, dtype=dtype) + xp.eye(23, k=-6, dtype=dtype)
495572
return xp.linalg.matrix_power(a, -987654321987654321)
496573

@@ -504,6 +581,7 @@ def test_matrix_power_invlarge(self, xp, dtype):
504581
)
505582
@pytest.mark.parametrize("n", [0, 5, -7])
506583
class TestMatrixPowerBatched:
584+
507585
@testing.for_float_dtypes(no_float16=True)
508586
@testing.numpy_cupy_allclose(rtol=5e-5)
509587
def test_matrix_power_batched(self, xp, dtype, shape, n):

0 commit comments

Comments
 (0)