@@ -1128,7 +1128,7 @@ def test_onehot0(self):
1128
1128
_ = tf .identity (x_ , name = _TFOUTPUT )
1129
1129
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1130
1130
1131
- @unittest .skip ("" )
1131
+ @unittest .skip ("only rank 1 is currently implemented " )
1132
1132
def test_onehot1 (self ):
1133
1133
# only rank 1 is currently implemented
1134
1134
x_val = np .array ([[0 , 2 ], [1 , - 1 ]], dtype = np .int32 )
@@ -1139,12 +1139,14 @@ def test_onehot1(self):
1139
1139
self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1140
1140
1141
1141
def test_onehot2 (self ):
1142
- x_val = np .array ([0 , 1 , 2 , 1 , 2 , 0 , 1 , 2 , 1 , 2 ], dtype = np .int32 )
1143
- depth = 20
1144
- x = tf .placeholder (tf .int32 , x_val .shape , name = _TFINPUT )
1145
- x_ = tf .one_hot (x , depth , on_value = 5.0 , axis = - 1 , off_value = 1.0 , dtype = tf .float32 )
1146
- _ = tf .identity (x_ , name = _TFOUTPUT )
1147
- self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1142
+ for axis in [- 1 , 0 , 1 ]:
1143
+ tf .reset_default_graph ()
1144
+ x_val = np .array ([0 , 1 , 2 , 1 , 2 , 0 , 1 , 2 , 1 , 2 ], dtype = np .int32 )
1145
+ depth = 20
1146
+ x = tf .placeholder (tf .int32 , x_val .shape , name = _TFINPUT )
1147
+ x_ = tf .one_hot (x , depth , on_value = 5.0 , axis = axis , off_value = 1.0 , dtype = tf .float32 )
1148
+ _ = tf .identity (x_ , name = _TFOUTPUT )
1149
+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1148
1150
1149
1151
@check_opset_min_version (9 , "onehot" )
1150
1152
def test_onehot3 (self ):
@@ -1160,16 +1162,17 @@ def test_onehot3(self):
1160
1162
graph = self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1161
1163
self .assertTrue (len (group_nodes_by_type (graph )["OneHot" ]) == 1 , "onnx onehot should be used" )
1162
1164
# rank 2
1163
- for np_dtype , tf_dtype in zip ([np .int32 , np .int64 ], [tf .int32 , tf .int64 ]):
1164
- tf .reset_default_graph ()
1165
- x_val = np .arange (0 , 50 , dtype = np_dtype ).reshape ([- 1 , 10 ])
1166
- depth = np .array (20 ).astype (np .int64 )
1167
- x = tf .placeholder (tf_dtype , x_val .shape , name = _TFINPUT )
1168
- on_off = np .array ([5.6 , 1.2 ]).astype (np_dtype )
1169
- x_ = tf .one_hot (x , depth , on_value = on_off [0 ], axis = - 1 , off_value = on_off [1 ])
1170
- _ = tf .identity (x_ , name = _TFOUTPUT )
1171
- graph = self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1172
- self .assertTrue (len (group_nodes_by_type (graph )["OneHot" ]) == 1 , "onnx onehot should be used" )
1165
+ for aixs in [- 1 , 0 , 1 , 2 ]:
1166
+ for np_dtype , tf_dtype in zip ([np .int32 , np .int64 ], [tf .int32 , tf .int64 ]):
1167
+ tf .reset_default_graph ()
1168
+ x_val = np .arange (0 , 50 , dtype = np_dtype ).reshape ([- 1 , 10 ])
1169
+ depth = np .array (20 ).astype (np .int64 )
1170
+ x = tf .placeholder (tf_dtype , x_val .shape , name = _TFINPUT )
1171
+ on_off = np .array ([5.6 , 1.2 ]).astype (np_dtype )
1172
+ x_ = tf .one_hot (x , depth , on_value = on_off [0 ], axis = aixs , off_value = on_off [1 ])
1173
+ _ = tf .identity (x_ , name = _TFOUTPUT )
1174
+ graph = self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1175
+ self .assertTrue (len (group_nodes_by_type (graph )["OneHot" ]) == 1 , "onnx onehot should be used" )
1173
1176
1174
1177
@skip_caffe2_backend ("issue undefined dim 1" )
1175
1178
def test_flatten0 (self ):
0 commit comments