@@ -157,6 +157,55 @@ def test_2d(self):
157157 create_diagonal (xp .asarray ([[1 ]]), xp = xp )
158158
159159
160+ class TestExpandDims :
161+ def test_functionality (self ):
162+ def _squeeze_all (b : Array ) -> Array :
163+ """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
164+ for axis in range (b .ndim ):
165+ with contextlib .suppress (ValueError ):
166+ b = xp .squeeze (b , axis = axis )
167+ return b
168+
169+ s = (2 , 3 , 4 , 5 )
170+ a = xp .empty (s )
171+ for axis in range (- 5 , 4 ):
172+ b = expand_dims (a , axis = axis , xp = xp )
173+ assert b .shape [axis ] == 1
174+ assert _squeeze_all (b ).shape == s
175+
176+ def test_axis_tuple (self ):
177+ a = xp .empty ((3 , 3 , 3 ))
178+ assert expand_dims (a , axis = (0 , 1 , 2 ), xp = xp ).shape == (1 , 1 , 1 , 3 , 3 , 3 )
179+ assert expand_dims (a , axis = (0 , - 1 , - 2 ), xp = xp ).shape == (1 , 3 , 3 , 3 , 1 , 1 )
180+ assert expand_dims (a , axis = (0 , 3 , 5 ), xp = xp ).shape == (1 , 3 , 3 , 1 , 3 , 1 )
181+ assert expand_dims (a , axis = (0 , - 3 , - 5 ), xp = xp ).shape == (1 , 1 , 3 , 1 , 3 , 3 )
182+
183+ def test_axis_out_of_range (self ):
184+ s = (2 , 3 , 4 , 5 )
185+ a = xp .empty (s )
186+ with pytest .raises (IndexError , match = "out of bounds" ):
187+ expand_dims (a , axis = - 6 , xp = xp )
188+ with pytest .raises (IndexError , match = "out of bounds" ):
189+ expand_dims (a , axis = 5 , xp = xp )
190+
191+ a = xp .empty ((3 , 3 , 3 ))
192+ with pytest .raises (IndexError , match = "out of bounds" ):
193+ expand_dims (a , axis = (0 , - 6 ), xp = xp )
194+ with pytest .raises (IndexError , match = "out of bounds" ):
195+ expand_dims (a , axis = (0 , 5 ), xp = xp )
196+
197+ def test_repeated_axis (self ):
198+ a = xp .empty ((3 , 3 , 3 ))
199+ with pytest .raises (ValueError , match = "Duplicate dimensions" ):
200+ expand_dims (a , axis = (1 , 1 ), xp = xp )
201+
202+ def test_positive_negative_repeated (self ):
203+ # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
204+ a = xp .empty ((2 , 3 , 4 , 5 ))
205+ with pytest .raises (ValueError , match = "Duplicate dimensions" ):
206+ expand_dims (a , axis = (3 , - 3 ), xp = xp )
207+
208+
160209class TestKron :
161210 def test_basic (self ):
162211 # Using 0-dimensional array
@@ -222,55 +271,6 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]):
222271 assert_equal (k .shape , expected_shape , err_msg = "Unexpected shape from kron" )
223272
224273
225- class TestExpandDims :
226- def test_functionality (self ):
227- def _squeeze_all (b : Array ) -> Array :
228- """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
229- for axis in range (b .ndim ):
230- with contextlib .suppress (ValueError ):
231- b = xp .squeeze (b , axis = axis )
232- return b
233-
234- s = (2 , 3 , 4 , 5 )
235- a = xp .empty (s )
236- for axis in range (- 5 , 4 ):
237- b = expand_dims (a , axis = axis , xp = xp )
238- assert b .shape [axis ] == 1
239- assert _squeeze_all (b ).shape == s
240-
241- def test_axis_tuple (self ):
242- a = xp .empty ((3 , 3 , 3 ))
243- assert expand_dims (a , axis = (0 , 1 , 2 ), xp = xp ).shape == (1 , 1 , 1 , 3 , 3 , 3 )
244- assert expand_dims (a , axis = (0 , - 1 , - 2 ), xp = xp ).shape == (1 , 3 , 3 , 3 , 1 , 1 )
245- assert expand_dims (a , axis = (0 , 3 , 5 ), xp = xp ).shape == (1 , 3 , 3 , 1 , 3 , 1 )
246- assert expand_dims (a , axis = (0 , - 3 , - 5 ), xp = xp ).shape == (1 , 1 , 3 , 1 , 3 , 3 )
247-
248- def test_axis_out_of_range (self ):
249- s = (2 , 3 , 4 , 5 )
250- a = xp .empty (s )
251- with pytest .raises (IndexError , match = "out of bounds" ):
252- expand_dims (a , axis = - 6 , xp = xp )
253- with pytest .raises (IndexError , match = "out of bounds" ):
254- expand_dims (a , axis = 5 , xp = xp )
255-
256- a = xp .empty ((3 , 3 , 3 ))
257- with pytest .raises (IndexError , match = "out of bounds" ):
258- expand_dims (a , axis = (0 , - 6 ), xp = xp )
259- with pytest .raises (IndexError , match = "out of bounds" ):
260- expand_dims (a , axis = (0 , 5 ), xp = xp )
261-
262- def test_repeated_axis (self ):
263- a = xp .empty ((3 , 3 , 3 ))
264- with pytest .raises (ValueError , match = "Duplicate dimensions" ):
265- expand_dims (a , axis = (1 , 1 ), xp = xp )
266-
267- def test_positive_negative_repeated (self ):
268- # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
269- a = xp .empty ((2 , 3 , 4 , 5 ))
270- with pytest .raises (ValueError , match = "Duplicate dimensions" ):
271- expand_dims (a , axis = (3 , - 3 ), xp = xp )
272-
273-
274274class TestSetDiff1D :
275275 def test_setdiff1d (self ):
276276 x1 = xp .asarray ([6 , 5 , 4 , 7 , 1 , 2 , 7 , 4 ])
0 commit comments