5454lazy_xp_function (setdiff1d , jax_jit = False )
5555lazy_xp_function (sinc )
5656
57+ NestedFloatList = list [float ] | list ["NestedFloatList" ]
58+
5759
5860class TestApplyWhere :
5961 @staticmethod
@@ -291,7 +293,31 @@ def test_0D(self, xp: ModuleType):
291293 y = atleast_nd (x , ndim = 5 )
292294 xp_assert_equal (y , xp .ones ((1 , 1 , 1 , 1 , 1 )))
293295
294- def test_1D (self , xp : ModuleType ):
296+ @pytest .mark .parametrize (
297+ ("input_shape" , "ndim" , "expected_shape" ),
298+ [
299+ ((1 ,), 0 , (1 ,)),
300+ ((5 ,), 1 , (5 ,)),
301+ ((2 ,), 2 , (1 , 2 )),
302+ ((3 ,), 3 , (1 , 1 , 3 )),
303+ ((2 ,), 5 , (1 , 1 , 1 , 1 , 2 )),
304+ ],
305+ )
306+ def test_1D_shapes (
307+ self ,
308+ input_shape : tuple [int ],
309+ ndim : int ,
310+ expected_shape : tuple [int ],
311+ xp : ModuleType ,
312+ ):
313+ n = math .prod (input_shape )
314+ x = xp .asarray (np .arange (n ).reshape (input_shape ))
315+ y = atleast_nd (x , ndim = ndim )
316+
317+ assert y .shape == expected_shape
318+ assert xp .sum (y ) == int (n * (n - 1 ) / 2 )
319+
320+ def test_1D_values (self , xp : ModuleType ):
295321 x = xp .asarray ([0 , 1 ])
296322
297323 y = atleast_nd (x , ndim = 0 )
@@ -306,8 +332,32 @@ def test_1D(self, xp: ModuleType):
306332 y = atleast_nd (x , ndim = 5 )
307333 xp_assert_equal (y , xp .asarray ([[[[[0 , 1 ]]]]]))
308334
309- def test_2D (self , xp : ModuleType ):
310- x = xp .asarray ([[3.0 ]])
335+ @pytest .mark .parametrize (
336+ ("input_shape" , "ndim" , "expected_shape" ),
337+ [
338+ ((2 , 1 ), 0 , (2 , 1 )),
339+ ((5 , 2 ), 1 , (5 , 2 )),
340+ ((2 , 1 ), 2 , (2 , 1 )),
341+ ((3 , 1 ), 3 , (1 , 3 , 1 )),
342+ ((2 , 8 ), 5 , (1 , 1 , 1 , 2 , 8 )),
343+ ],
344+ )
345+ def test_2D_shapes (
346+ self ,
347+ input_shape : tuple [int ],
348+ ndim : int ,
349+ expected_shape : tuple [int ],
350+ xp : ModuleType ,
351+ ):
352+ n = math .prod (input_shape )
353+ x = xp .asarray (np .arange (n ).reshape (input_shape ))
354+ y = atleast_nd (x , ndim = ndim )
355+
356+ assert y .shape == expected_shape
357+ assert xp .sum (y ) == int (n * (n - 1 ) / 2 )
358+
359+ def test_2D_values (self , xp : ModuleType ):
360+ x = xp .asarray ([[3.0 ], [4.0 ]])
311361
312362 y = atleast_nd (x , ndim = 0 )
313363 xp_assert_equal (y , x )
@@ -316,12 +366,36 @@ def test_2D(self, xp: ModuleType):
316366 xp_assert_equal (y , x )
317367
318368 y = atleast_nd (x , ndim = 3 )
319- xp_assert_equal (y , 3 * xp .ones (( 1 , 1 , 1 ) ))
369+ xp_assert_equal (y , xp .asarray ([[[ 3.0 ], [ 4.0 ]]] ))
320370
321371 y = atleast_nd (x , ndim = 5 )
322- xp_assert_equal (y , 3 * xp .ones ((1 , 1 , 1 , 1 , 1 )))
372+ xp_assert_equal (y , xp .asarray ([[[[[3.0 ], [4.0 ]]]]]))
373+
374+ @pytest .mark .parametrize (
375+ ("input_shape" , "ndim" , "expected_shape" ),
376+ [
377+ ((2 , 1 , 1 ), 0 , (2 , 1 , 1 )),
378+ ((1 , 5 , 2 ), 1 , (1 , 5 , 2 )),
379+ ((2 , 1 , 1 ), 2 , (2 , 1 , 1 )),
380+ ((1 , 3 , 1 ), 3 , (1 , 3 , 1 )),
381+ ((2 , 8 , 1 ), 5 , (1 , 1 , 2 , 8 , 1 )),
382+ ],
383+ )
384+ def test_3D_shapes (
385+ self ,
386+ input_shape : tuple [int ],
387+ ndim : int ,
388+ expected_shape : tuple [int ],
389+ xp : ModuleType ,
390+ ):
391+ n = math .prod (input_shape )
392+ x = xp .asarray (np .arange (n ).reshape (input_shape ))
393+ y = atleast_nd (x , ndim = ndim )
394+
395+ assert y .shape == expected_shape
396+ assert xp .sum (y ) == int (n * (n - 1 ) / 2 )
323397
324- def test_3D (self , xp : ModuleType ):
398+ def test_3D_values (self , xp : ModuleType ):
325399 x = xp .asarray ([[[3.0 ], [2.0 ]]])
326400
327401 y = atleast_nd (x , ndim = 0 )
@@ -336,8 +410,32 @@ def test_3D(self, xp: ModuleType):
336410 y = atleast_nd (x , ndim = 5 )
337411 xp_assert_equal (y , xp .asarray ([[[[[3.0 ], [2.0 ]]]]]))
338412
339- def test_5D (self , xp : ModuleType ):
340- x = xp .ones ((1 , 1 , 1 , 1 , 1 ))
413+ @pytest .mark .parametrize (
414+ ("input_shape" , "ndim" , "expected_shape" ),
415+ [
416+ ((2 , 1 , 1 , 2 , 1 ), 0 , (2 , 1 , 1 , 2 , 1 )),
417+ ((1 , 5 , 2 , 3 , 2 ), 2 , (1 , 5 , 2 , 3 , 2 )),
418+ ((2 , 1 , 1 , 5 , 2 ), 5 , (2 , 1 , 1 , 5 , 2 )),
419+ ((1 , 3 , 1 , 2 , 1 ), 6 , (1 , 1 , 3 , 1 , 2 , 1 )),
420+ ((2 , 8 , 1 , 9 , 8 ), 9 , (1 , 1 , 1 , 1 , 2 , 8 , 1 , 9 , 8 )),
421+ ],
422+ )
423+ def test_5D_shapes (
424+ self ,
425+ input_shape : tuple [int ],
426+ ndim : int ,
427+ expected_shape : tuple [int ],
428+ xp : ModuleType ,
429+ ):
430+ n = math .prod (input_shape )
431+ x = xp .asarray (np .arange (n ).reshape (input_shape ))
432+ y = atleast_nd (x , ndim = ndim )
433+
434+ assert y .shape == expected_shape
435+ assert xp .sum (y ) == int (n * (n - 1 ) / 2 )
436+
437+ def test_5D_values (self , xp : ModuleType ):
438+ x = xp .asarray ([[[[[3.0 ]], [[2.0 ]]]]])
341439
342440 y = atleast_nd (x , ndim = 0 )
343441 xp_assert_equal (y , x )
@@ -349,19 +447,10 @@ def test_5D(self, xp: ModuleType):
349447 xp_assert_equal (y , x )
350448
351449 y = atleast_nd (x , ndim = 6 )
352- xp_assert_equal (y , xp .ones (( 1 , 1 , 1 , 1 , 1 , 1 ) ))
450+ xp_assert_equal (y , xp .asarray ([[[[[[ 3.0 ]], [[ 2.0 ]]]]]] ))
353451
354452 y = atleast_nd (x , ndim = 9 )
355- xp_assert_equal (y , xp .ones ((1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 )))
356-
357- def test_device (self , xp : ModuleType , device : Device ):
358- x = xp .asarray ([1 , 2 , 3 ], device = device )
359- assert get_device (atleast_nd (x , ndim = 2 )) == device
360-
361- def test_xp (self , xp : ModuleType ):
362- x = xp .asarray (1.0 )
363- y = atleast_nd (x , ndim = 1 , xp = xp )
364- xp_assert_equal (y , xp .ones ((1 ,)))
453+ xp_assert_equal (y , xp .asarray ([[[[[[[[[3.0 ]], [[2.0 ]]]]]]]]]))
365454
366455
367456class TestBroadcastShapes :
0 commit comments