Skip to content

Commit 5b7f6d3

Browse files
committed
Update manipulation_tests/test_shape.py
1 parent 15988e2 commit 5b7f6d3

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

dpnp/tests/third_party/cupy/manipulation_tests/test_shape.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
@pytest.mark.parametrize("shape", [(2, 3), (), (4,)])
1010
class TestShape:
11+
1112
def test_shape(self, shape):
1213
for xp in (numpy, cupy):
1314
a = testing.shaped_arange(shape, xp)
@@ -20,10 +21,13 @@ def test_shape_list(self, shape):
2021

2122

2223
class TestReshape:
23-
def test_reshape_shapes(self):
24+
25+
def test_reshape_strides(self):
2426
def func(xp):
2527
a = testing.shaped_arange((1, 1, 1, 2, 2), xp)
26-
return a.shape
28+
if xp is cupy:
29+
return tuple(el * a.itemsize for el in a.strides)
30+
return a.strides
2731

2832
assert func(numpy) == func(cupy)
2933

@@ -98,15 +102,21 @@ def test_reshape_zerosize_invalid_unknown(self):
98102
def test_reshape_zerosize(self, xp):
99103
a = xp.zeros((0,))
100104
b = a.reshape((0,))
101-
# assert b.base is a
105+
if xp is cupy:
106+
assert a.get_array()._pointer == b.get_array()._pointer
107+
else:
108+
assert b.base is a
102109
return b
103110

104111
@testing.for_orders("CFA")
105112
@testing.numpy_cupy_array_equal(type_check=has_support_aspect64())
106113
def test_reshape_zerosize2(self, xp, order):
107114
a = xp.zeros((2, 0, 3))
108115
b = a.reshape((5, 0, 4), order=order)
109-
# assert b.base is a
116+
if xp is cupy:
117+
assert a.get_array()._pointer == b.get_array()._pointer
118+
else:
119+
assert b.base is a
110120
return b
111121

112122
@testing.for_orders("CFA")
@@ -141,6 +151,7 @@ def test_ndim_limit2(self, dtype, order):
141151

142152

143153
class TestRavel:
154+
144155
@testing.for_orders("CFA")
145156
# order = 'K' is not supported currently
146157
@testing.numpy_cupy_array_equal()
@@ -233,6 +244,7 @@ def test_external_ravel(self, xp):
233244
],
234245
)
235246
class TestReshapeOrder:
247+
236248
def test_reshape_contiguity(self, order_init, order_reshape, shape_in_out):
237249
shape_init, shape_final = shape_in_out
238250

@@ -247,4 +259,5 @@ def test_reshape_contiguity(self, order_init, order_reshape, shape_in_out):
247259
assert b_cupy.flags.f_contiguous == b_numpy.flags.f_contiguous
248260
assert b_cupy.flags.c_contiguous == b_numpy.flags.c_contiguous
249261

262+
# testing.assert_array_equal(b_cupy.strides, b_numpy.strides)
250263
testing.assert_array_equal(b_cupy, b_numpy)

0 commit comments

Comments
 (0)