@@ -805,7 +805,7 @@ def test_onehot0(self):
805
805
x_val = np .array ([0 , 1 , 2 ], dtype = np .int32 )
806
806
depth = 3
807
807
x = tf .placeholder (tf .int32 , x_val .shape , name = _TFINPUT )
808
- x_ = tf .one_hot (x , depth , on_value = 5.0 , axis = 1 , off_value = 1.0 , dtype = tf .float32 )
808
+ x_ = tf .one_hot (x , depth , on_value = 5.0 , axis = 0 , off_value = 1.0 , dtype = tf .float32 )
809
809
output = tf .identity (x_ , name = _TFOUTPUT )
810
810
actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
811
811
self .assertAllClose (expected , actual )
@@ -821,6 +821,16 @@ def test_onehot1(self):
821
821
actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
822
822
self .assertAllClose (expected , actual )
823
823
824
+ def test_onehot2 (self ):
825
+ # no such op in onnx
826
+ x_val = np .array ([0 , 1 , 2 , 1 , 2 , 0 , 1 , 2 , 1 , 2 ], dtype = np .int32 )
827
+ depth = 20
828
+ x = tf .placeholder (tf .int32 , x_val .shape , name = _TFINPUT )
829
+ x_ = tf .one_hot (x , depth , on_value = 5.0 , axis = - 1 , off_value = 1.0 , dtype = tf .float32 )
830
+ output = tf .identity (x_ , name = _TFOUTPUT )
831
+ actual , expected = self ._run (output , {x : x_val }, {_INPUT : x_val })
832
+ self .assertAllClose (expected , actual )
833
+
824
834
@unittest .skipIf (BACKEND in ["caffe2" ], "issue undefined dim 1" )
825
835
def test_flatten0 (self ):
826
836
x_val = np .array ([[[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 , 9 ]]], dtype = np .float32 )
0 commit comments