@@ -1290,12 +1290,12 @@ def onehot_op(ctx, node, name, args):
1290
1290
def onehot_op9 (ctx , node , name , args ):
1291
1291
# T output = OneHot(uint8/int32/int64 input, T depth, T on-value, T off-value, @int axis, @dtype)
1292
1292
# tf requires that dtype is same as on-value's and off-value's dtype
1293
- # in ONNX, op's schema is (input, depth, [off-value, on-value], @int axis)
1293
+ # in ONNX, op's schema is (input, depth, value, @int axis), meaning of "value" is [off-value, on-value]
1294
1294
# onnxruntime only supports int64
1295
1295
output_dtype = ctx .get_dtype (node .input [2 ])
1296
- if output_dtype not in [onnx_pb .TensorProto .INT64 , onnx_pb .TensorProto .INT32 ]:
1296
+ if ctx . is_target ( TARGET_RS6 ) and output_dtype not in [onnx_pb .TensorProto .INT64 , onnx_pb .TensorProto .INT32 ]:
1297
1297
log .warning ("unsupported dtype in onnxruntime, onehot can't be used directly" )
1298
- return onehot_op (ctx , node , name , args )
1298
+ onehot_op (ctx , node , name , args )
1299
1299
1300
1300
depth = node .input [1 ]
1301
1301
depth = ctx .make_node ("Unsqueeze" , [depth ], attr = {"axes" : [0 ]}).output [0 ]
@@ -1307,21 +1307,21 @@ def onehot_op9(ctx, node, name, args):
1307
1307
off_on_value = ctx .make_node ("Concat" , [off_value , on_value ], attr = {"axis" : 0 }).output [0 ]
1308
1308
1309
1309
indices = node .input [0 ]
1310
- if ctx .get_dtype (indices ) != onnx_pb .TensorProto .INT64 :
1310
+ if ctx .get_dtype (indices ) != onnx_pb .TensorProto .INT64 and ctx . is_target ( TARGET_RS6 ) :
1311
1311
indices = ctx .make_node ("Cast" , [indices ], attr = {"to" : onnx_pb .TensorProto .INT64 }).output [0 ]
1312
- node .input [0 ] = indices
1312
+ node .input [0 ] = indices
1313
1313
1314
- if ctx .get_dtype (depth ) != onnx_pb .TensorProto .INT64 :
1314
+ if ctx .get_dtype (depth ) != onnx_pb .TensorProto .INT64 and ctx . is_target ( TARGET_RS6 ) :
1315
1315
depth = ctx .make_node ("Cast" , [depth ], attr = {"to" : onnx_pb .TensorProto .INT64 }).output [0 ]
1316
1316
node .input [1 ] = depth
1317
1317
1318
- if output_dtype != onnx_pb .TensorProto .INT64 :
1318
+ if output_dtype != onnx_pb .TensorProto .INT64 and ctx . is_target ( TARGET_RS6 ) :
1319
1319
off_on_value = ctx .make_node ("Cast" , [off_on_value ], attr = {"to" : onnx_pb .TensorProto .INT64 }).output [0 ]
1320
1320
node .input [2 ] = off_on_value
1321
1321
1322
1322
del node .input [3 ]
1323
1323
1324
- if output_dtype != onnx_pb .TensorProto .INT64 :
1324
+ if output_dtype != onnx_pb .TensorProto .INT64 and ctx . is_target ( TARGET_RS6 ) :
1325
1325
new_output = utils .make_name ("onehot_output" )
1326
1326
ctx .insert_new_node_on_output ("Cast" , node .output [0 ], new_output , to = output_dtype )
1327
1327
0 commit comments