Skip to content

Commit 02082d3

Browse files
author
Vahid Tavanashad
committed
update test to pass with numpy-2.3
1 parent 2d27110 commit 02082d3

File tree

6 files changed

+27
-22
lines changed

6 files changed

+27
-22
lines changed

.github/workflows/conda-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ env:
2222
test-env-name: 'test'
2323
rerun-tests-on-failure: 'true'
2424
rerun-tests-max-attempts: 2
25-
rerun-tests-timeout: 35
25+
rerun-tests-timeout: 40
2626

2727
jobs:
2828
build:

dpnp/tests/test_linalg.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2104,11 +2104,9 @@ def test_empty(self, shape, ord, axis, keepdims):
21042104
assert_raises(ValueError, dpnp.linalg.norm, ia, **kwarg)
21052105
assert_raises(ValueError, numpy.linalg.norm, a, **kwarg)
21062106
else:
2107-
# TODO: when similar changes in numpy are available, instead
2108-
# of assert_equal with zero, we should compare with numpy
2109-
# ord in [None, 1, 2]
2110-
assert_equal(dpnp.linalg.norm(ia, **kwarg), 0.0)
2111-
assert_raises(ValueError, numpy.linalg.norm, a, **kwarg)
2107+
result = dpnp.linalg.norm(ia, **kwarg)
2108+
expected = numpy.linalg.norm(a, **kwarg)
2109+
assert_dtype_allclose(result, expected)
21122110
else:
21132111
result = dpnp.linalg.norm(ia, **kwarg)
21142112
expected = numpy.linalg.norm(a, **kwarg)

dpnp/tests/test_product.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212
assert_dtype_allclose,
1313
generate_random_numpy_array,
1414
get_all_dtypes,
15-
get_complex_dtypes,
16-
is_win_platform,
1715
numpy_version,
1816
)
19-
from .third_party.cupy import testing
17+
from .third_party.cupy.testing import with_requires
2018

2119
# A list of selected dtypes including both integer and float dtypes
2220
# to test differennt backends: OneMath (for float) and dpctl (for integer)
@@ -141,7 +139,7 @@ def test_strided(self, dtype, stride):
141139
expected = numpy.cross(a[::stride], b[::stride])
142140
assert_dtype_allclose(result, expected)
143141

144-
@testing.with_requires("numpy>=2.0")
142+
@with_requires("numpy>=2.0")
145143
@pytest.mark.parametrize("axis", [0, 1, -1])
146144
def test_linalg(self, axis):
147145
a = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
@@ -171,7 +169,7 @@ def test_error(self):
171169
with pytest.raises(TypeError):
172170
dpnp.cross(a, a)
173171

174-
@testing.with_requires("numpy>=2.0")
172+
@with_requires("numpy>=2.0")
175173
def test_linalg_error(self):
176174
a = dpnp.arange(4)
177175
b = dpnp.arange(4)
@@ -706,7 +704,7 @@ def test_axes_ND_ND(self, axes):
706704
expected = numpy.matmul(a, b, axes=axes)
707705
assert_dtype_allclose(result, expected)
708706

709-
@testing.with_requires("numpy>=2.2")
707+
@with_requires("numpy>=2.2")
710708
@pytest.mark.parametrize("func", ["matmul", "matvec"])
711709
@pytest.mark.parametrize(
712710
"axes",
@@ -725,7 +723,7 @@ def test_axes_ND_1D(self, func, axes):
725723
expected = getattr(numpy, func)(a, b, axes=axes)
726724
assert_dtype_allclose(result, expected)
727725

728-
@testing.with_requires("numpy>=2.2")
726+
@with_requires("numpy>=2.2")
729727
@pytest.mark.parametrize("func", ["matmul", "vecmat"])
730728
@pytest.mark.parametrize(
731729
"axes",
@@ -845,6 +843,8 @@ def test_dtype_matrix(self, dt_in1, dt_in2, dt_out, shape1, shape2):
845843
assert_raises(TypeError, dpnp.matmul, ia, ib, out=iout)
846844
assert_raises(TypeError, numpy.matmul, a, b, out=out)
847845

846+
# TODO: include numpy-2.3 when numpy-issue-29164 is resolved
847+
@with_requires("numpy<2.3")
848848
@pytest.mark.parametrize("dtype", _selected_dtypes)
849849
@pytest.mark.parametrize("order1", ["C", "F", "A"])
850850
@pytest.mark.parametrize("order2", ["C", "F", "A"])
@@ -882,6 +882,8 @@ def test_order(self, dtype, order1, order2, order, shape1, shape2):
882882
assert result.flags.f_contiguous == expected.flags.f_contiguous
883883
assert_dtype_allclose(result, expected)
884884

885+
# TODO: include numpy-2.3 when numpy-issue-29164 is resolved
886+
@with_requires("numpy<2.3")
885887
@pytest.mark.parametrize("dtype", _selected_dtypes)
886888
@pytest.mark.parametrize(
887889
"stride",
@@ -971,7 +973,7 @@ def test_strided3(self, dtype, stride, transpose):
971973
assert result is iout
972974
assert_dtype_allclose(result, expected)
973975

974-
@testing.with_requires("numpy>=2.2")
976+
@with_requires("numpy>=2.2")
975977
@pytest.mark.parametrize("dtype", _selected_dtypes)
976978
@pytest.mark.parametrize("func", ["matmul", "matvec"])
977979
@pytest.mark.parametrize("incx", [-2, 2])
@@ -1007,7 +1009,7 @@ def test_strided_mat_vec(self, dtype, func, incx, incy, transpose):
10071009
assert result is iout
10081010
assert_dtype_allclose(result, expected)
10091011

1010-
@testing.with_requires("numpy>=2.2")
1012+
@with_requires("numpy>=2.2")
10111013
@pytest.mark.parametrize("dtype", _selected_dtypes)
10121014
@pytest.mark.parametrize("func", ["matmul", "vecmat"])
10131015
@pytest.mark.parametrize("incx", [-2, 2])
@@ -1198,7 +1200,7 @@ def test_bool(self):
11981200
expected = numpy.matmul(a, b, out=out)
11991201
assert_dtype_allclose(result, expected)
12001202

1201-
@testing.with_requires("numpy>=2.0")
1203+
@with_requires("numpy>=2.0")
12021204
def test_linalg_matmul(self):
12031205
a = numpy.ones((3, 4))
12041206
b = numpy.ones((4, 5))
@@ -1434,7 +1436,7 @@ def test_invalid_axes(self, xp):
14341436
assert_raises(AxisError, xp.matmul, a, b, axes=axes)
14351437

14361438

1437-
@testing.with_requires("numpy>=2.2")
1439+
@with_requires("numpy>=2.2")
14381440
class TestMatvec:
14391441
def setup_method(self):
14401442
numpy.random.seed(42)
@@ -1735,7 +1737,7 @@ def test_strided(self, stride):
17351737
expected = numpy.tensordot(a, a, axes=axes)
17361738
assert_dtype_allclose(result, expected)
17371739

1738-
@testing.with_requires("numpy>=2.0")
1740+
@with_requires("numpy>=2.0")
17391741
@pytest.mark.parametrize(
17401742
"axes",
17411743
[([0, 1]), ([0, 1], [1, 2]), ([-2, -3], [3, 2])],
@@ -1880,7 +1882,7 @@ def test_error(self):
18801882
dpnp.vdot(b, a)
18811883

18821884

1883-
@testing.with_requires("numpy>=2.0")
1885+
@with_requires("numpy>=2.0")
18841886
class TestVecdot:
18851887
def setup_method(self):
18861888
numpy.random.seed(42)
@@ -2160,7 +2162,7 @@ def test_error(self, xp):
21602162
assert_raises(ValueError, xp.vecdot, a, a, axes=axes)
21612163

21622164

2163-
@testing.with_requires("numpy>=2.2")
2165+
@with_requires("numpy>=2.2")
21642166
class TestVecmat:
21652167
def setup_method(self):
21662168
numpy.random.seed(42)

dpnp/tests/testing/array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
from dpnp.dpnp_utils import convert_item
3131

32+
from ..helper import numpy_version
33+
3234
assert_allclose_orig = numpy.testing.assert_allclose
3335
assert_almost_equal_orig = numpy.testing.assert_almost_equal
3436
assert_array_almost_equal_orig = numpy.testing.assert_array_almost_equal
@@ -49,7 +51,7 @@ def _assert(assert_func, result, expected, *args, **kwargs):
4951
]
5052
# For numpy < 2.0, some tests will fail for dtype mismatch
5153
dev = dpctl.select_default_device()
52-
if numpy.__version__ >= "2.0.0" and dev.has_aspect_fp64:
54+
if numpy_version() >= "2.0.0" and dev.has_aspect_fp64:
5355
strict = kwargs.setdefault("strict", True)
5456
if flag:
5557
if strict:

dpnp/tests/third_party/cupy/manipulation_tests/test_add_remove.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,8 @@ def test_unique_inverse(self, xp, dtype, attr):
340340
a = testing.shaped_random((100, 100), xp, dtype)
341341
return getattr(xp.unique_inverse(a), attr)
342342

343-
@testing.with_requires("numpy>=2.0")
343+
# TODO: include numpy-2.3 when dpnp-issue-2476 is addressed
344+
@testing.with_requires("numpy>=2.0", "numpy<2.3")
344345
@testing.for_all_dtypes(no_float16=True, no_bool=True, no_complex=True)
345346
@testing.numpy_cupy_array_equal()
346347
def test_unique_values(self, xp, dtype):

dpnp/tests/third_party/cupy/math_tests/test_matmul.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ def test_cupy_matmul(self, xp, dtype1, dtype2):
8383
return xp.matmul(x1, x2)
8484

8585

86+
# TODO: include numpy-2.3 when numpy-issue-29164 is resolved
87+
@testing.with_requires("numpy<2.3")
8688
@testing.parameterize(
8789
*testing.product(
8890
{

0 commit comments

Comments
 (0)