Skip to content

Commit 15988e2

Browse files
committed
Update manipulation_tests/test_rearrange.py
1 parent 7ec0856 commit 15988e2

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

dpnp/tests/third_party/cupy/manipulation_tests/test_rearrange.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
{"shape": (5, 2), "shift": (2, 1, 3), "axis": None},
2727
)
2828
class TestRoll(unittest.TestCase):
29+
2930
@testing.for_all_dtypes()
3031
@testing.numpy_cupy_array_equal()
3132
def test_roll(self, xp, dtype):
@@ -37,17 +38,26 @@ def test_roll(self, xp, dtype):
3738
def test_roll_cupy_shift(self, xp, dtype):
3839
x = testing.shaped_arange(self.shape, xp, dtype)
3940
shift = self.shift
41+
if xp is cupy:
42+
shift = cupy.array(shift)
4043
return xp.roll(x, shift, axis=self.axis)
4144

4245

4346
class TestRollTypeError(unittest.TestCase):
44-
# TODO: update, once dpctl#1857 is resolved
45-
@testing.with_requires("numpy<2.1.2") # done in numpy#27437
47+
48+
@testing.with_requires("numpy>=2.1.2")
49+
def test_roll_invalid_shift_castable(self):
50+
for xp in (numpy, cupy):
51+
x = testing.shaped_arange((5, 2), xp)
52+
# Weird but works due to `int` call
53+
xp.roll(x, "0", axis=0)
54+
55+
@testing.with_requires("numpy>=2.1.2")
4656
def test_roll_invalid_shift(self):
4757
for xp in (numpy, cupy):
4858
x = testing.shaped_arange((5, 2), xp)
49-
with pytest.raises(TypeError):
50-
xp.roll(x, "0", axis=0)
59+
with pytest.raises(ValueError):
60+
xp.roll(x, "a", axis=0)
5161

5262
def test_roll_invalid_axis_type(self):
5363
for xp in (numpy, cupy):
@@ -75,11 +85,14 @@ def test_roll_invalid_cupy_shift(self):
7585
for xp in (numpy, cupy):
7686
x = testing.shaped_arange(self.shape, xp)
7787
shift = self.shift
88+
if xp is cupy:
89+
shift = cupy.array(shift)
7890
with pytest.raises(ValueError):
7991
xp.roll(x, shift, axis=self.axis)
8092

8193

8294
class TestFliplr(unittest.TestCase):
95+
8396
@testing.for_all_dtypes()
8497
@testing.numpy_cupy_array_equal()
8598
def test_fliplr_2(self, xp, dtype):
@@ -101,6 +114,7 @@ def test_fliplr_insufficient_ndim(self, dtype):
101114

102115

103116
class TestFlipud(unittest.TestCase):
117+
104118
@testing.for_all_dtypes()
105119
@testing.numpy_cupy_array_equal()
106120
def test_flipud_1(self, xp, dtype):
@@ -122,6 +136,7 @@ def test_flipud_insufficient_ndim(self, dtype):
122136

123137

124138
class TestFlip(unittest.TestCase):
139+
125140
@testing.for_all_dtypes()
126141
@testing.numpy_cupy_array_equal()
127142
def test_flip_1(self, xp, dtype):
@@ -205,6 +220,7 @@ def test_flip_invalid_negative_axis(self, dtype):
205220

206221

207222
class TestRot90(unittest.TestCase):
223+
208224
@testing.for_all_dtypes()
209225
@testing.numpy_cupy_array_equal()
210226
def test_rot90_none(self, xp, dtype):

0 commit comments

Comments
 (0)