Skip to content

Commit 1f0447b

Browse files
committed
Update manipulation_tests/test_transpose.py
1 parent 5b7f6d3 commit 1f0447b

File tree

1 file changed

+37
-5
lines changed

1 file changed

+37
-5
lines changed

dpnp/tests/third_party/cupy/manipulation_tests/test_transpose.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010

1111
class TestTranspose(unittest.TestCase):
12+
1213
@testing.numpy_cupy_array_equal()
1314
def test_moveaxis1(self, xp):
1415
a = testing.shaped_arange((2, 3, 4), xp)
@@ -52,6 +53,12 @@ def test_moveaxis_invalid1_2(self):
5253
with pytest.raises(AxisError):
5354
xp.moveaxis(a, [0, 1], [1, 3])
5455

56+
def test_moveaxis_invalid1_3(self):
57+
for xp in (numpy, cupy):
58+
a = testing.shaped_arange((2, 3, 4), xp)
59+
with pytest.raises(AxisError):
60+
xp.moveaxis(a, 0, 3)
61+
5562
# dim is too small
5663
def test_moveaxis_invalid2_1(self):
5764
for xp in (numpy, cupy):
@@ -158,12 +165,37 @@ def test_external_transpose(self, xp):
158165
a = testing.shaped_arange((2, 3, 4), xp)
159166
return xp.transpose(a, (-1, 0, 1))
160167

161-
@testing.numpy_cupy_array_equal()
162-
def test_external_transpose_5d(self, xp):
163-
a = testing.shaped_arange((2, 3, 4, 5, 6), xp)
164-
return xp.transpose(a, (1, 0, 3, 4, 2))
165-
166168
@testing.numpy_cupy_array_equal()
167169
def test_external_transpose_all(self, xp):
168170
a = testing.shaped_arange((2, 3, 4), xp)
169171
return xp.transpose(a)
172+
173+
174+
ARRAY_SHAPES_TO_TEST = (
175+
(5, 2),
176+
(5, 2, 3),
177+
(5, 2, 3, 4),
178+
)
179+
180+
181+
class TestMatrixTranspose:
182+
183+
@testing.with_requires("numpy>=2.0")
184+
def test_matrix_transpose_raises_error_for_1d(self):
185+
msg = "matrix transpose with ndim < 2 is undefined"
186+
arr = cupy.arange(48)
187+
with pytest.raises(ValueError, match=msg):
188+
arr.mT
189+
190+
@testing.numpy_cupy_array_equal()
191+
def test_matrix_transpose_equals_transpose_2d(self, xp):
192+
arr = xp.arange(48).reshape((6, 8))
193+
return arr
194+
195+
@testing.with_requires("numpy>=2.0")
196+
@pytest.mark.parametrize("shape", ARRAY_SHAPES_TO_TEST)
197+
@testing.numpy_cupy_array_equal()
198+
def test_matrix_transpose_equals_swapaxes(self, xp, shape):
199+
vec = xp.arange(shape[-1])
200+
arr = xp.broadcast_to(vec, shape)
201+
return arr.mT

0 commit comments

Comments
 (0)