@@ -1287,6 +1287,45 @@ def onehot_op(ctx, node, name, args):
1287
1287
return
1288
1288
1289
1289
1290
def onehot_op9(ctx, node, name, args):
    """Convert TF OneHot to the native ONNX OneHot op (opset 9+).

    TF signature:   OneHot(indices, depth, on_value, off_value, @axis, @dtype),
                    where @dtype must equal the dtype of on_value/off_value.
    ONNX signature: OneHot(indices, depth, values, @axis),
                    where values is the rank-1 tensor [off_value, on_value].

    onnxruntime only implements OneHot for int64, so the int inputs are cast
    to int64 and the result is cast back to the requested dtype when needed.
    Non-integer dtypes fall back to the opset-8 decomposition (onehot_op).
    """
    # on_value's dtype (input 2) defines the TF output dtype.
    output_dtype = ctx.get_dtype(node.input[2])
    if output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
        # Runtime limitation: delegate anything non-int32/int64 to the
        # element-wise decomposition used before opset 9.
        log.warning("unsupported dtype in onnxruntime, onehot can't be used directly")
        return onehot_op(ctx, node, name, args)

    # ONNX wants depth as a rank-1 tensor; TF supplies a scalar.
    depth = ctx.make_node("Unsqueeze", [node.input[1]], attr={"axes": [0]}).output[0]

    # Fold TF's two scalar values into the single rank-1 [off, on] tensor
    # the ONNX schema expects.
    on_value = ctx.make_node("Unsqueeze", [node.input[2]], attr={"axes": [0]}).output[0]
    off_value = ctx.make_node("Unsqueeze", [node.input[3]], attr={"axes": [0]}).output[0]
    off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]

    # onnxruntime only accepts int64 indices.
    indices = node.input[0]
    if ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64:
        indices = ctx.make_node("Cast", [indices], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
    node.input[0] = indices

    # NOTE(review): dtype is queried on the freshly created Unsqueeze output;
    # this assumes make_node propagates the input dtype — confirm in ctx.
    if ctx.get_dtype(depth) != onnx_pb.TensorProto.INT64:
        depth = ctx.make_node("Cast", [depth], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
    node.input[1] = depth

    if output_dtype != onnx_pb.TensorProto.INT64:
        off_on_value = ctx.make_node("Cast", [off_on_value], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
    node.input[2] = off_on_value

    # TF passes 4 inputs but ONNX OneHot takes 3 — off_value was merged
    # into off_on_value above, so the 4th input is dropped.
    del node.input[3]

    # Cast the int64 OneHot result back to the caller's requested dtype.
    if output_dtype != onnx_pb.TensorProto.INT64:
        cast_name = utils.make_name("onehot_output")
        ctx.insert_new_node_on_output("Cast", node.output[0], cast_name, to=output_dtype)
1290
1329
def fused_batchnorm_op7 (ctx , node , name , args ):
1291
1330
node .type = "BatchNormalization"
1292
1331
# tf inputs: x, scale, bias, mean, variance
@@ -1865,6 +1904,7 @@ def where_op(ctx, node, name, args):
1865
1904
"Erf" : (direct_op , []),
1866
1905
"Fill" : (fill_op , []),
1867
1906
"IsNan" : (direct_op , ["IsNaN" ]),
1907
+ "OneHot" : (onehot_op9 , []),
1868
1908
"ResizeBilinear" : (upsample_op9 , ["Upsample" , "linear" ]),
1869
1909
"ResizeNearestNeighbor" : (upsample_op9 , ["Upsample" , "nearest" ]),
1870
1910
"ReverseSequence" : (reverse_op9 , []),
0 commit comments