88
99@pytest .mark .parametrize ("shape" , [(2 , 3 ), (), (4 ,)])
1010class TestShape :
11+
1112 def test_shape (self , shape ):
1213 for xp in (numpy , cupy ):
1314 a = testing .shaped_arange (shape , xp )
@@ -20,10 +21,13 @@ def test_shape_list(self, shape):
2021
2122
2223class TestReshape :
23- def test_reshape_shapes (self ):
24+
25+ def test_reshape_strides (self ):
2426 def func (xp ):
2527 a = testing .shaped_arange ((1 , 1 , 1 , 2 , 2 ), xp )
26- return a .shape
28+ if xp is cupy :
29+ return tuple (el * a .itemsize for el in a .strides )
30+ return a .strides
2731
2832 assert func (numpy ) == func (cupy )
2933
@@ -98,15 +102,21 @@ def test_reshape_zerosize_invalid_unknown(self):
98102 def test_reshape_zerosize (self , xp ):
99103 a = xp .zeros ((0 ,))
100104 b = a .reshape ((0 ,))
101- # assert b.base is a
105+ if xp is cupy :
106+ assert a .get_array ()._pointer == b .get_array ()._pointer
107+ else :
108+ assert b .base is a
102109 return b
103110
104111 @testing .for_orders ("CFA" )
105112 @testing .numpy_cupy_array_equal (type_check = has_support_aspect64 ())
106113 def test_reshape_zerosize2 (self , xp , order ):
107114 a = xp .zeros ((2 , 0 , 3 ))
108115 b = a .reshape ((5 , 0 , 4 ), order = order )
109- # assert b.base is a
116+ if xp is cupy :
117+ assert a .get_array ()._pointer == b .get_array ()._pointer
118+ else :
119+ assert b .base is a
110120 return b
111121
112122 @testing .for_orders ("CFA" )
@@ -141,6 +151,7 @@ def test_ndim_limit2(self, dtype, order):
141151
142152
143153class TestRavel :
154+
144155 @testing .for_orders ("CFA" )
145156 # order = 'K' is not supported currently
146157 @testing .numpy_cupy_array_equal ()
@@ -233,6 +244,7 @@ def test_external_ravel(self, xp):
233244 ],
234245)
235246class TestReshapeOrder :
247+
236248 def test_reshape_contiguity (self , order_init , order_reshape , shape_in_out ):
237249 shape_init , shape_final = shape_in_out
238250
@@ -247,4 +259,5 @@ def test_reshape_contiguity(self, order_init, order_reshape, shape_in_out):
247259 assert b_cupy .flags .f_contiguous == b_numpy .flags .f_contiguous
248260 assert b_cupy .flags .c_contiguous == b_numpy .flags .c_contiguous
249261
262+ # testing.assert_array_equal(b_cupy.strides, b_numpy.strides)
250263 testing .assert_array_equal (b_cupy , b_numpy )
0 commit comments