Skip to content

Commit 6f7c8b5

Browse files
committed
refactor code
1 parent 40ca2cd commit 6f7c8b5

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

tf2onnx/tfonnx.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,12 +1290,12 @@ def onehot_op(ctx, node, name, args):
12901290
def onehot_op9(ctx, node, name, args):
12911291
# T output = OneHot(uint8/int32/int64 input, T depth, T on-value, T off-value, @int axis, @dtype)
12921292
# tf requires that dtype is same as on-value's and off-value's dtype
1293-
# in ONNX, op's schema is (input, depth, [off-value, on-value], @int axis)
1293+
# in ONNX, op's schema is (input, depth, value, @int axis), meaning of "value" is [off-value, on-value]
12941294
# onnxruntime only supports int64
12951295
output_dtype = ctx.get_dtype(node.input[2])
1296-
if output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
1296+
if ctx.is_target(TARGET_RS6) and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
12971297
log.warning("unsupported dtype in onnxruntime, onehot can't be used directly")
1298-
return onehot_op(ctx, node, name, args)
1298+
onehot_op(ctx, node, name, args)
12991299

13001300
depth = node.input[1]
13011301
depth = ctx.make_node("Unsqueeze", [depth], attr={"axes": [0]}).output[0]
@@ -1307,21 +1307,21 @@ def onehot_op9(ctx, node, name, args):
13071307
off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]
13081308

13091309
indices = node.input[0]
1310-
if ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64:
1310+
if ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64 and ctx.is_target(TARGET_RS6):
13111311
indices = ctx.make_node("Cast", [indices], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
1312-
node.input[0] = indices
1312+
node.input[0] = indices
13131313

1314-
if ctx.get_dtype(depth) != onnx_pb.TensorProto.INT64:
1314+
if ctx.get_dtype(depth) != onnx_pb.TensorProto.INT64 and ctx.is_target(TARGET_RS6):
13151315
depth = ctx.make_node("Cast", [depth], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
13161316
node.input[1] = depth
13171317

1318-
if output_dtype != onnx_pb.TensorProto.INT64:
1318+
if output_dtype != onnx_pb.TensorProto.INT64 and ctx.is_target(TARGET_RS6):
13191319
off_on_value = ctx.make_node("Cast", [off_on_value], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
13201320
node.input[2] = off_on_value
13211321

13221322
del node.input[3]
13231323

1324-
if output_dtype != onnx_pb.TensorProto.INT64:
1324+
if output_dtype != onnx_pb.TensorProto.INT64 and ctx.is_target(TARGET_RS6):
13251325
new_output = utils.make_name("onehot_output")
13261326
ctx.insert_new_node_on_output("Cast", node.output[0], new_output, to=output_dtype)
13271327

0 commit comments

Comments
 (0)