@@ -1128,7 +1128,7 @@ def test_onehot0(self):
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
-    @unittest.skip("")
+    @unittest.skip("only rank 1 is currently implemented")
     def test_onehot1(self):
         # only rank 1 is currently implemented
         x_val = np.array([[0, 2], [1, -1]], dtype=np.int32)
@@ -1139,12 +1139,40 @@ def test_onehot1(self):
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
     def test_onehot2(self):
-        x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
-        depth = 20
-        x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
-        x_ = tf.one_hot(x, depth, on_value=5.0, axis=-1, off_value=1.0, dtype=tf.float32)
-        _ = tf.identity(x_, name=_TFOUTPUT)
-        self._run_test_case([_OUTPUT], {_INPUT: x_val})
+        for axis in [-1, 0, 1]:
+            tf.reset_default_graph()
+            x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
+            depth = 20
+            x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
+            x_ = tf.one_hot(x, depth, on_value=5.0, axis=axis, off_value=1.0, dtype=tf.float32)
+            _ = tf.identity(x_, name=_TFOUTPUT)
+            self._run_test_case([_OUTPUT], {_INPUT: x_val})
+
+    @check_opset_min_version(9, "onehot")
+    def test_onehot3(self):
+        # rank 1
+        for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
+            tf.reset_default_graph()
+            x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np_dtype)
+            depth = np.array(20).astype(np.int64)
+            x = tf.placeholder(tf_dtype, x_val.shape, name=_TFINPUT)
+            on_off = np.array([5.6, 1.2]).astype(np_dtype)
+            x_ = tf.one_hot(x, depth, on_value=on_off[0], axis=-1, off_value=on_off[1])
+            _ = tf.identity(x_, name=_TFOUTPUT)
+            graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
+            self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")
+        # rank 2
+        for axis in [-1, 0, 1, 2]:
+            for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
+                tf.reset_default_graph()
+                x_val = np.arange(0, 50, dtype=np_dtype).reshape([-1, 10])
+                depth = np.array(20).astype(np.int64)
+                x = tf.placeholder(tf_dtype, x_val.shape, name=_TFINPUT)
+                on_off = np.array([5.6, 1.2]).astype(np_dtype)
+                x_ = tf.one_hot(x, depth, on_value=on_off[0], axis=axis, off_value=on_off[1])
+                _ = tf.identity(x_, name=_TFOUTPUT)
+                graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
+                self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")
 
     @skip_caffe2_backend("issue undefined dim 1")
     def test_flatten0(self):
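
For context (not part of this diff): these tests check that tf.one_hot converts to a single ONNX OneHot node, which is available from opset 9 and takes indices, a depth scalar, and an [off_value, on_value] values pair, plus an axis attribute. Below is a minimal NumPy sketch of the semantics being exercised; the helper name one_hot_reference is hypothetical, not part of the tf2onnx test suite.

    import numpy as np

    def one_hot_reference(indices, depth, on_value, off_value, axis=-1):
        # Compare each index against [0, 1, ..., depth-1]; out-of-range
        # indices (e.g. the -1 in test_onehot1) match nothing and stay "off".
        indices = np.asarray(indices)
        hot = indices[..., None] == np.arange(depth)
        out = np.where(hot, on_value, off_value)
        # The depth dimension is created last, then moved to `axis`,
        # which is how the tests can loop axis over [-1, 0, 1, ...].
        return np.moveaxis(out, -1, axis)

    # e.g. one_hot_reference([0, 2, 1, -1], 3, 5.0, 1.0) gives
    # [[5. 1. 1.], [1. 1. 5.], [1. 5. 1.], [1. 1. 1.]]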