@@ -490,6 +490,38 @@ def test_stack():
490
490
stack ((a , b ), dtype = np .int64 , axis = 1 , casting = "safe" )
491
491
492
492
493
+ def test_unstack ():
494
+ a = np .arange (24 ).reshape ((2 , 3 , 4 ))
495
+
496
+ for stacks in [np .unstack (a ),
497
+ np .unstack (a , axis = 0 ),
498
+ np .unstack (a , axis = - 3 )]:
499
+ assert isinstance (stacks , tuple )
500
+ assert len (stacks ) == 2
501
+ assert_array_equal (stacks [0 ], a [0 ])
502
+ assert_array_equal (stacks [1 ], a [1 ])
503
+
504
+ for stacks in [np .unstack (a , axis = 1 ),
505
+ np .unstack (a , axis = - 2 )]:
506
+ assert isinstance (stacks , tuple )
507
+ assert len (stacks ) == 3
508
+ assert_array_equal (stacks [0 ], a [:, 0 ])
509
+ assert_array_equal (stacks [1 ], a [:, 1 ])
510
+ assert_array_equal (stacks [2 ], a [:, 2 ])
511
+
512
+ for stacks in [np .unstack (a , axis = 2 ),
513
+ np .unstack (a , axis = - 1 )]:
514
+ assert isinstance (stacks , tuple )
515
+ assert len (stacks ) == 4
516
+ assert_array_equal (stacks [0 ], a [:, :, 0 ])
517
+ assert_array_equal (stacks [1 ], a [:, :, 1 ])
518
+ assert_array_equal (stacks [2 ], a [:, :, 2 ])
519
+ assert_array_equal (stacks [3 ], a [:, :, 3 ])
520
+
521
+ assert_raises (ValueError , np .unstack , a , axis = 3 )
522
+ assert_raises (ValueError , np .unstack , a , axis = - 4 )
523
+
524
+
493
525
@pytest .mark .parametrize ("axis" , [0 ])
494
526
@pytest .mark .parametrize ("out_dtype" , ["c8" , "f4" , "f8" , ">f8" , "i8" ])
495
527
@pytest .mark .parametrize ("casting" ,
0 commit comments