@@ -294,147 +294,163 @@ def test_0D(self, xp: ModuleType):
294294 xp_assert_equal (y , xp .ones ((1 , 1 , 1 , 1 , 1 )))
295295
296296 @pytest .mark .parametrize (
297- ("x_data " , "ndim" , "expected_data " ),
297+ ("input_shape " , "ndim" , "expected_shape " ),
298298 [
299- # --- size-1 vector ---
300- ([3.0 ], 0 , [3.0 ]),
301- ([3.0 ], 1 , [3.0 ]),
302- ([3.0 ], 2 , [[3.0 ]]),
303- ([3.0 ], 3 , [[[3.0 ]]]),
304- ([3.0 ], 5 , [[[[[3.0 ]]]]]),
305- # --- size-2 vector ---
306- ([0.0 , 1.0 ], 0 , [0.0 , 1.0 ]),
307- ([0.0 , 1.0 ], 1 , [0.0 , 1.0 ]),
308- ([0.0 , 1.0 ], 2 , [[0.0 , 1.0 ]]),
309- ([0.0 , 1.0 ], 5 , [[[[[0.0 , 1.0 ]]]]]),
299+ ((1 ,), 0 , (1 ,)),
300+ ((5 ,), 1 , (5 ,)),
301+ ((2 ,), 2 , (1 , 2 )),
302+ ((3 ,), 3 , (1 , 1 , 3 )),
303+ ((2 ,), 5 , (1 , 1 , 1 , 1 , 2 )),
310304 ],
311305 )
312- def test_1D (
306+ def test_1D_shapes (
313307 self ,
314- x_data : NestedFloatList ,
308+ input_shape : tuple [ int ] ,
315309 ndim : int ,
316- expected_data : NestedFloatList ,
310+ expected_shape : tuple [ int ] ,
317311 xp : ModuleType ,
318312 ):
319- x = xp . asarray ( x_data )
320- expected = xp .asarray (expected_data )
313+ n = math . prod ( input_shape )
314+ x = xp .reshape ( xp . asarray (list ( range ( n ))), input_shape )
321315 y = atleast_nd (x , ndim = ndim )
322- xp_assert_equal (y , expected )
323316
324- @pytest .mark .parametrize (
325- ("x_data" , "ndim" , "expected_data" ),
326- [
327- # --- size-1 vector ---
328- ([[3.0 ]], 0 , [[3.0 ]]),
329- ([[3.0 ]], 1 , [[3.0 ]]),
330- ([[3.0 ]], 2 , [[3.0 ]]),
331- ([[3.0 ]], 3 , [[[3.0 ]]]),
332- ([[3.0 ]], 5 , [[[[[3.0 ]]]]]),
333- # --- size-2 vector ---
334- ([[0.0 ], [1.0 ]], 0 , [[0.0 ], [1.0 ]]),
335- ([[0.0 , 1.0 ]], 1 , [[0.0 , 1.0 ]]),
336- ([[0.0 , 1.0 ]], 2 , [[0.0 , 1.0 ]]),
337- ([[0.0 ], [1.0 ]], 3 , [[[0.0 ], [1.0 ]]]),
338- ([[0.0 , 1.0 ]], 5 , [[[[[0.0 , 1.0 ]]]]]),
339- ],
340- )
341- def test_2D (
342- self ,
343- x_data : NestedFloatList ,
344- ndim : int ,
345- expected_data : NestedFloatList ,
346- xp : ModuleType ,
347- ):
348- x = xp .asarray (x_data )
349- expected = xp .asarray (expected_data )
350- y = atleast_nd (x , ndim = ndim )
351- xp_assert_equal (y , expected )
317+ assert y .shape == expected_shape
318+ assert xp .sum (y ) == int (n * (n - 1 ) / 2 )
319+
320+ def test_1D_values (self , xp : ModuleType ):
321+ x = xp .asarray ([0 , 1 ])
322+
323+ y = atleast_nd (x , ndim = 0 )
324+ xp_assert_equal (y , x )
325+
326+ y = atleast_nd (x , ndim = 1 )
327+ xp_assert_equal (y , x )
328+
329+ y = atleast_nd (x , ndim = 2 )
330+ xp_assert_equal (y , xp .asarray ([[0 , 1 ]]))
331+
332+ y = atleast_nd (x , ndim = 5 )
333+ xp_assert_equal (y , xp .asarray ([[[[[0 , 1 ]]]]]))
352334
353335 @pytest .mark .parametrize (
354- ("x_data " , "ndim" , "expected_data " ),
336+ ("input_shape " , "ndim" , "expected_shape " ),
355337 [
356- ([[[ 0.0 ]], [[ 1.0 ]]] , 0 , [[[ 0.0 ]], [[ 1.0 ]]] ),
357- ([[[ 0.0 ], [ 1.0 ]]] , 1 , [[[ 0.0 ], [ 1.0 ]]] ),
358- ([[[ 0.0 , 1.0 ]]] , 2 , [[[ 0.0 , 1.0 ]]] ),
359- ([[[ 0.0 ]], [[ 1.0 ]]] , 3 , [[[ 0.0 ]], [[ 1.0 ]]] ),
360- ([[[ 0.0 ], [ 1.0 ]]] , 5 , [[[[[ 0.0 ], [ 1.0 ]]]]] ),
338+ (( 2 , 1 ) , 0 , ( 2 , 1 ) ),
339+ (( 5 , 2 ) , 1 , ( 5 , 2 ) ),
340+ (( 2 , 1 ) , 2 , ( 2 , 1 ) ),
341+ (( 3 , 1 ) , 3 , ( 1 , 3 , 1 ) ),
342+ (( 2 , 8 ) , 5 , ( 1 , 1 , 1 , 2 , 8 ) ),
361343 ],
362344 )
363- def test_3D (
345+ def test_2D_shapes (
364346 self ,
365- x_data : NestedFloatList ,
347+ input_shape : tuple [ int ] ,
366348 ndim : int ,
367- expected_data : NestedFloatList ,
349+ expected_shape : tuple [ int ] ,
368350 xp : ModuleType ,
369351 ):
370- x = xp . asarray ( x_data )
371- expected = xp .asarray (expected_data )
352+ n = math . prod ( input_shape )
353+ x = xp .reshape ( xp . asarray (list ( range ( n ))), input_shape )
372354 y = atleast_nd (x , ndim = ndim )
373- xp_assert_equal (y , expected )
355+
356+ assert y .shape == expected_shape
357+ assert xp .sum (y ) == int (n * (n - 1 ) / 2 )
358+
359+ def test_2D_values (self , xp : ModuleType ):
360+ x = xp .asarray ([[3.0 ], [4.0 ]])
361+
362+ y = atleast_nd (x , ndim = 0 )
363+ xp_assert_equal (y , x )
364+
365+ y = atleast_nd (x , ndim = 2 )
366+ xp_assert_equal (y , x )
367+
368+ y = atleast_nd (x , ndim = 3 )
369+ xp_assert_equal (y , xp .asarray ([[[3.0 ], [4.0 ]]]))
370+
371+ y = atleast_nd (x , ndim = 5 )
372+ xp_assert_equal (y , xp .asarray ([[[[[3.0 ], [4.0 ]]]]]))
374373
375374 @pytest .mark .parametrize (
376- ("x_data " , "ndim" , "expected_data " ),
375+ ("input_shape " , "ndim" , "expected_shape " ),
377376 [
378- ([[[[3.0 ], [2.0 ]]]], 0 , [[[[3.0 ], [2.0 ]]]]),
379- ([[[[3.0 , 2.0 ]]]], 2 , [[[[3.0 , 2.0 ]]]]),
380- ([[[[3.0 ]], [[2.0 ]]]], 4 , [[[[3.0 ]], [[2.0 ]]]]),
381- ([[[[3.0 ]]], [[[2.0 ]]]], 5 , [[[[[3.0 ]]], [[[2.0 ]]]]]),
377+ ((2 , 1 , 1 ), 0 , (2 , 1 , 1 )),
378+ ((1 , 5 , 2 ), 1 , (1 , 5 , 2 )),
379+ ((2 , 1 , 1 ), 2 , (2 , 1 , 1 )),
380+ ((1 , 3 , 1 ), 3 , (1 , 3 , 1 )),
381+ ((2 , 8 , 1 ), 5 , (1 , 1 , 2 , 8 , 1 )),
382382 ],
383383 )
384- def test_4D (
384+ def test_3D_shapes (
385385 self ,
386- x_data : NestedFloatList ,
386+ input_shape : tuple [ int ] ,
387387 ndim : int ,
388- expected_data : NestedFloatList ,
388+ expected_shape : tuple [ int ] ,
389389 xp : ModuleType ,
390390 ):
391- x = xp . asarray ( x_data )
392- expected = xp .asarray (expected_data )
391+ n = math . prod ( input_shape )
392+ x = xp .reshape ( xp . asarray (list ( range ( n ))), input_shape )
393393 y = atleast_nd (x , ndim = ndim )
394- xp_assert_equal (y , expected )
394+
395+ assert y .shape == expected_shape
396+ assert xp .sum (y ) == int (n * (n - 1 ) / 2 )
397+
398+ def test_3D_values (self , xp : ModuleType ):
399+ x = xp .asarray ([[[3.0 ], [2.0 ]]])
400+
401+ y = atleast_nd (x , ndim = 0 )
402+ xp_assert_equal (y , x )
403+
404+ y = atleast_nd (x , ndim = 2 )
405+ xp_assert_equal (y , x )
406+
407+ y = atleast_nd (x , ndim = 3 )
408+ xp_assert_equal (y , x )
409+
410+ y = atleast_nd (x , ndim = 5 )
411+ xp_assert_equal (y , xp .asarray ([[[[[3.0 ], [2.0 ]]]]]))
395412
396413 @pytest .mark .parametrize (
397- ("x_data " , "ndim" , "expected_data " ),
414+ ("input_shape " , "ndim" , "expected_shape " ),
398415 [
399- ([[[[[3.0 ]], [[2.0 ]], [[1.0 ]]]]], 0 , [[[[[3.0 ]], [[2.0 ]], [[1.0 ]]]]]),
400- ([[[[[3.0 , 2.0 , 6.0 ]]]]], 2 , [[[[[3.0 , 2.0 , 6.0 ]]]]]),
401- (
402- [[[[[3.0 ]]], [[[2.0 ]]], [[[1.0 ]]]]],
403- 4 ,
404- [[[[[3.0 ]]], [[[2.0 ]]], [[[1.0 ]]]]],
405- ),
406- (
407- [[[[[3.0 ]], [[1.0 ]]], [[[2.0 ]], [[1.0 ]]], [[[1.0 ]], [[1.0 ]]]]],
408- 6 ,
409- [[[[[[3.0 ]], [[1.0 ]]], [[[2.0 ]], [[1.0 ]]], [[[1.0 ]], [[1.0 ]]]]]],
410- ),
411- (
412- [[[[[3.0 ]], [[1.0 ]]], [[[2.0 ]], [[1.0 ]]], [[[1.0 ]], [[1.0 ]]]]],
413- 9 ,
414- [[[[[[[[[3.0 ]], [[1.0 ]]], [[[2.0 ]], [[1.0 ]]], [[[1.0 ]], [[1.0 ]]]]]]]]],
415- ),
416+ ((2 , 1 , 1 , 2 , 1 ), 0 , (2 , 1 , 1 , 2 , 1 )),
417+ ((1 , 5 , 2 , 3 , 2 ), 2 , (1 , 5 , 2 , 3 , 2 )),
418+ ((2 , 1 , 1 , 5 , 2 ), 5 , (2 , 1 , 1 , 5 , 2 )),
419+ ((1 , 3 , 1 , 2 , 1 ), 6 , (1 , 1 , 3 , 1 , 2 , 1 )),
420+ ((2 , 8 , 1 , 9 , 8 ), 9 , (1 , 1 , 1 , 1 , 2 , 8 , 1 , 9 , 8 )),
416421 ],
417422 )
418- def test_5D (
423+ def test_5D_shapes (
419424 self ,
420- x_data : NestedFloatList ,
425+ input_shape : tuple [ int ] ,
421426 ndim : int ,
422- expected_data : NestedFloatList ,
427+ expected_shape : tuple [ int ] ,
423428 xp : ModuleType ,
424429 ):
425- x = xp . asarray ( x_data )
426- expected = xp .asarray (expected_data )
430+ n = math . prod ( input_shape )
431+ x = xp .reshape ( xp . asarray (list ( range ( n ))), input_shape )
427432 y = atleast_nd (x , ndim = ndim )
428- xp_assert_equal (y , expected )
429433
430- def test_device (self , xp : ModuleType , device : Device ):
431- x = xp .asarray ([1 , 2 , 3 ], device = device )
432- assert get_device (atleast_nd (x , ndim = 2 )) == device
434+ assert y .shape == expected_shape
435+ assert xp .sum (y ) == int (n * (n - 1 ) / 2 )
433436
434- def test_xp (self , xp : ModuleType ):
435- x = xp .asarray (1.0 )
436- y = atleast_nd (x , ndim = 1 , xp = xp )
437- xp_assert_equal (y , xp .ones ((1 ,)))
437+ def test_5D_values (self , xp : ModuleType ):
438+ x = xp .asarray ([[[[[3.0 ]], [[2.0 ]]]]])
439+
440+ y = atleast_nd (x , ndim = 0 )
441+ xp_assert_equal (y , x )
442+
443+ y = atleast_nd (x , ndim = 4 )
444+ xp_assert_equal (y , x )
445+
446+ y = atleast_nd (x , ndim = 5 )
447+ xp_assert_equal (y , x )
448+
449+ y = atleast_nd (x , ndim = 6 )
450+ xp_assert_equal (y , xp .asarray ([[[[[[3.0 ]], [[2.0 ]]]]]]))
451+
452+ y = atleast_nd (x , ndim = 9 )
453+ xp_assert_equal (y , xp .asarray ([[[[[[[[[3.0 ]], [[2.0 ]]]]]]]]]))
438454
439455
440456class TestBroadcastShapes :
0 commit comments