Skip to content

Commit 40ca2cd

Browse files
committed
Use the ONNX "OneHot" op to map TF "one_hot" where possible
1 parent f2dc6d8 commit 40ca2cd

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

tests/test_backend.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,6 +1146,31 @@ def test_onehot2(self):
11461146
_ = tf.identity(x_, name=_TFOUTPUT)
11471147
self._run_test_case([_OUTPUT], {_INPUT: x_val})
11481148

1149+
@check_opset_min_version(9, "onehot")
1150+
def test_onehot3(self):
1151+
# rank 1
1152+
for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
1153+
tf.reset_default_graph()
1154+
x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np_dtype)
1155+
depth = np.array(20).astype(np.int64)
1156+
x = tf.placeholder(tf_dtype, x_val.shape, name=_TFINPUT)
1157+
on_off = np.array([5.6, 1.2]).astype(np_dtype)
1158+
x_ = tf.one_hot(x, depth, on_value=on_off[0], axis=-1, off_value=on_off[1])
1159+
_ = tf.identity(x_, name=_TFOUTPUT)
1160+
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
1161+
self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")
1162+
# rank 2
1163+
for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
1164+
tf.reset_default_graph()
1165+
x_val = np.arange(0, 50, dtype=np_dtype).reshape([-1, 10])
1166+
depth = np.array(20).astype(np.int64)
1167+
x = tf.placeholder(tf_dtype, x_val.shape, name=_TFINPUT)
1168+
on_off = np.array([5.6, 1.2]).astype(np_dtype)
1169+
x_ = tf.one_hot(x, depth, on_value=on_off[0], axis=-1, off_value=on_off[1])
1170+
_ = tf.identity(x_, name=_TFOUTPUT)
1171+
graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
1172+
self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")
1173+
11491174
@skip_caffe2_backend("issue undefined dim 1")
11501175
def test_flatten0(self):
11511176
x_val = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=np.float32)

tf2onnx/tfonnx.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,6 +1287,45 @@ def onehot_op(ctx, node, name, args):
12871287
return
12881288

12891289

def onehot_op9(ctx, node, name, args):
    """Map tf "OneHot" onto the native ONNX OneHot op (opset 9+)."""
    # T output = OneHot(uint8/int32/int64 input, T depth, T on-value, T off-value, @int axis, @dtype)
    # tf requires that dtype is same as on-value's and off-value's dtype
    # in ONNX, op's schema is (input, depth, [off-value, on-value], @int axis)
    # onnxruntime only supports int64
    output_dtype = ctx.get_dtype(node.input[2])
    if output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
        # Fall back to the pre-opset-9 expansion for dtypes the runtime rejects.
        log.warning("unsupported dtype in onnxruntime, onehot can't be used directly")
        return onehot_op(ctx, node, name, args)

    # NOTE(review): depth is unsqueezed from tf's scalar to a 1-D tensor —
    # presumably what the target runtime expects here; confirm against the
    # ONNX OneHot schema for the opset being emitted.
    depth = node.input[1]
    depth = ctx.make_node("Unsqueeze", [depth], attr={"axes": [0]}).output[0]

    # Fold tf's two scalar inputs (on_value, off_value) into the single
    # 2-element [off_value, on_value] tensor ONNX takes as its third input.
    on_value = node.input[2]
    off_value = node.input[3]
    on_value = ctx.make_node("Unsqueeze", [on_value], attr={"axes": [0]}).output[0]
    off_value = ctx.make_node("Unsqueeze", [off_value], attr={"axes": [0]}).output[0]
    off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]

    # Cast indices to int64 when needed (onnxruntime limitation noted above),
    # then rewire input 0 either way.
    indices = node.input[0]
    if ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64:
        indices = ctx.make_node("Cast", [indices], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
    node.input[0] = indices

    # input 1 must be rewired unconditionally: `depth` was rebound to the
    # Unsqueeze output above even when no Cast is inserted.
    if ctx.get_dtype(depth) != onnx_pb.TensorProto.INT64:
        depth = ctx.make_node("Cast", [depth], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
    node.input[1] = depth

    # Likewise input 2 always becomes the concatenated pair; the Cast is only
    # needed when the tf dtype was int32 (computation is forced to int64).
    if output_dtype != onnx_pb.TensorProto.INT64:
        off_on_value = ctx.make_node("Cast", [off_on_value], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
    node.input[2] = off_on_value

    # tf's 4th input (off_value) has been folded into input 2; drop it so the
    # node matches ONNX OneHot's 3-input schema.
    del node.input[3]

    if output_dtype != onnx_pb.TensorProto.INT64:
        # Computation was forced to int64; cast the result back to the dtype
        # the tf graph's consumers expect.
        new_output = utils.make_name("onehot_output")
        ctx.insert_new_node_on_output("Cast", node.output[0], new_output, to=output_dtype)
12901329
def fused_batchnorm_op7(ctx, node, name, args):
12911330
node.type = "BatchNormalization"
12921331
# tf inputs: x, scale, bias, mean, variance
@@ -1865,6 +1904,7 @@ def where_op(ctx, node, name, args):
18651904
"Erf": (direct_op, []),
18661905
"Fill": (fill_op, []),
18671906
"IsNan": (direct_op, ["IsNaN"]),
1907+
"OneHot": (onehot_op9, []),
18681908
"ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
18691909
"ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
18701910
"ReverseSequence": (reverse_op9, []),

0 commit comments

Comments
 (0)