Skip to content

Commit bdd7617

Browse files
Fix conversion of OneHot for dtypes unsupported by ORT (#1675)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 0083dfa commit bdd7617

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

tests/test_backend.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2271,13 +2271,15 @@ def func(x):
22712271
return tf.identity(x_, name=_TFOUTPUT)
22722272
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
22732273

2274-
@check_target("rs6", "onehot")
2274+
@skip_tfjs("tfjs produces incorrect results")
22752275
def test_onehot0(self):
22762276
x_val = np.array([0, 1, 2], dtype=np.int32)
22772277
depth = 5
2278-
for axis in [-1, 0, 1]:
2278+
for dtype, axis in [(tf.float32, -1), (tf.int64, 0), (tf.float64, 1)]:
22792279
def func(x):
2280-
x_ = tf.one_hot(x, depth, on_value=5.0, axis=axis, off_value=1.0, dtype=tf.float32)
2280+
val1 = tf.constant(5, dtype)
2281+
val2 = tf.constant(1, dtype)
2282+
x_ = tf.one_hot(x, depth, on_value=val1, axis=axis, off_value=val2, dtype=dtype)
22812283
return tf.identity(x_, name=_TFOUTPUT)
22822284
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
22832285

tf2onnx/onnx_opset/tensor.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1313,8 +1313,11 @@ def version_1(cls, ctx, node, **kwargs):
13131313
if axis.i == 0:
13141314
# TODO: revisit for rank > 1
13151315
name = utils.make_name(node.name)
1316-
transpose_node = ctx.insert_new_node_on_output("Transpose", node.output[0], name)
1317-
ctx.copy_shape(node.output[0], transpose_node.output[0])
1316+
shape = ctx.get_shape(node.output[0])
1317+
transpose_node = ctx.make_node("Transpose", [node.output[0]], name=name, shapes=[shape])
1318+
ctx.insert_node_on_output(transpose_node, node.output[0])
1319+
if shape is not None:
1320+
ctx.set_shape(node.output[0], shape[::-1])
13181321

13191322
@classmethod
13201323
def any_version_after9(cls, opset, ctx, node, **kwargs):
@@ -1323,9 +1326,11 @@ def any_version_after9(cls, opset, ctx, node, **kwargs):
13231326
# in ONNX, op's schema is (input, depth, value, @int axis), meaning of "value" is [off-value, on-value]
13241327
# onnxruntime only supports int64
13251328
output_dtype = ctx.get_dtype(node.input[2])
1326-
if ctx.is_target(constants.TARGET_RS6) \
1327-
and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
1328-
logger.warning("unsupported dtype in onnxruntime, onehot-9 can't be used directly")
1329+
supported_dtypes = [onnx_pb.TensorProto.FLOAT]
1330+
if ctx.is_target(constants.TARGET_RS6):
1331+
supported_dtypes = [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]
1332+
if output_dtype not in supported_dtypes:
1333+
logger.warning("unsupported dtype in target runtime, OneHot op can't be used directly")
13291334
cls.version_1(ctx, node, **kwargs)
13301335
return
13311336

0 commit comments

Comments
 (0)