@@ -834,18 +834,30 @@ def test_topk(self):
834
834
actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
835
835
self .assertAllClose (expected , actual )
836
836
837
- def test_stack_axis0 (self ):
838
- x_val = [np .random .randn (3 , 4 ).astype ("float32" ) for _ in range (10 )]
839
- x = [tf .constant (x_val [i ], dtype = tf .float32 ) for i in range (10 )]
840
- x_ = tf .stack (x , axis = 0 )
841
- output = tf .identity (x_ , name = _TFOUTPUT )
842
- actual , expected = self ._run (output , {}, {})
843
- self .assertAllClose (expected , actual )
844
-
845
- def test_stack_axis1 (self ):
846
- x_val = [np .random .randn (3 , 4 ).astype ("float32" ) for _ in range (10 )]
847
- x = [tf .constant (x_val [i ], dtype = tf .float32 ) for i in range (10 )]
848
- x_ = tf .stack (x , axis = 1 )
837
+ def test_stack_axis (self ):
838
+ for axis in [0 , 1 ]:
839
+ tf .reset_default_graph ()
840
+ x_val = [np .random .randn (3 , 4 ).astype ("float32" ) for _ in range (10 )]
841
+ x = [tf .constant (x_val [i ], dtype = tf .float32 ) for i in range (10 )]
842
+ x_ = tf .stack (x , axis = axis )
843
+ output = tf .identity (x_ , name = _TFOUTPUT )
844
+ actual , expected = self ._run (output , {}, {})
845
+ self .assertAllClose (expected , actual )
846
+
847
+ def test_unstack_axis (self ):
848
+ for axis in [0 , 1 ]:
849
+ tf .reset_default_graph ()
850
+ x_val = np .random .randn (10 , 3 , 4 ).astype ("float32" )
851
+ x = tf .constant (x_val , dtype = tf .float32 )
852
+ x_ = tf .unstack (x , axis = axis )
853
+ output = tf .identity (x_ , name = _TFOUTPUT )
854
+ actual , expected = self ._run (output , {}, {})
855
+ self .assertAllClose (expected , actual )
856
+
857
+ def test_unstack_axis1 (self ):
858
+ x_val = np .random .randn (10 , 3 , 4 ).astype ("float32" )
859
+ x = tf .constant (x_val , dtype = tf .float32 )
860
+ x_ = tf .unstack (x , axis = 1 )
849
861
output = tf .identity (x_ , name = _TFOUTPUT )
850
862
actual , expected = self ._run (output , {}, {})
851
863
self .assertAllClose (expected , actual )
0 commit comments