
Commit b2d3f97

Merge pull request #399 from zhijxu-MS/use_onehot_to_map_onehot
use onnx "onehot" to map tf "onehot" if possible
2 parents 62d7d8d + 931117d commit b2d3f97
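
This merge teaches the converter to map TensorFlow's OneHot op onto the native ONNX OneHot op (added in opset 9) instead of rebuilding it from primitive ops. The mapping hinges on a schema difference: tf.one_hot takes scalar on_value and off_value arguments, while ONNX OneHot takes a single rank-1 "values" tensor laid out as [off_value, on_value]. A minimal numpy sketch of that correspondence (onnx_onehot_ref is a hypothetical helper for illustration, not code from this commit):

import numpy as np

def onnx_onehot_ref(indices, depth, values, axis=-1):
    # values uses the ONNX layout: [off_value, on_value]
    off_value, on_value = values
    hot = np.eye(depth)[indices.reshape(-1)].reshape(indices.shape + (depth,))
    hot = np.moveaxis(hot, -1, axis)  # put the depth axis where tf.one_hot would
    # (out-of-range index handling is omitted in this sketch)
    return np.where(hot == 1, on_value, off_value)

# tf.one_hot(x, 20, on_value=5.0, off_value=1.0, axis=-1) corresponds to
# ONNX OneHot(x, depth=20, values=[1.0, 5.0], axis=-1)
x = np.array([0, 2, 1], dtype=np.int64)
print(onnx_onehot_ref(x, 20, np.array([1.0, 5.0])))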

File tree: 2 files changed (+78 −7 lines)

  tests/test_backend.py
  tf2onnx/tfonnx.py

tests/test_backend.py

Lines changed: 35 additions & 7 deletions
@@ -1128,7 +1128,7 @@ def test_onehot0(self):
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})

-    @unittest.skip("")
+    @unittest.skip("only rank 1 is currently implemented")
     def test_onehot1(self):
         # only rank 1 is currently implemented
         x_val = np.array([[0, 2], [1, -1]], dtype=np.int32)
@@ -1139,12 +1139,40 @@ def test_onehot1(self):
         self._run_test_case([_OUTPUT], {_INPUT: x_val})

     def test_onehot2(self):
-        x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
-        depth = 20
-        x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
-        x_ = tf.one_hot(x, depth, on_value=5.0, axis=-1, off_value=1.0, dtype=tf.float32)
-        _ = tf.identity(x_, name=_TFOUTPUT)
-        self._run_test_case([_OUTPUT], {_INPUT: x_val})
+        for axis in [-1, 0, 1]:
+            tf.reset_default_graph()
+            x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
+            depth = 20
+            x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
+            x_ = tf.one_hot(x, depth, on_value=5.0, axis=axis, off_value=1.0, dtype=tf.float32)
+            _ = tf.identity(x_, name=_TFOUTPUT)
+            self._run_test_case([_OUTPUT], {_INPUT: x_val})
+
+    @check_opset_min_version(9, "onehot")
+    def test_onehot3(self):
+        # rank 1
+        for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
+            tf.reset_default_graph()
+            x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np_dtype)
+            depth = np.array(20).astype(np.int64)
+            x = tf.placeholder(tf_dtype, x_val.shape, name=_TFINPUT)
+            on_off = np.array([5.6, 1.2]).astype(np_dtype)
+            x_ = tf.one_hot(x, depth, on_value=on_off[0], axis=-1, off_value=on_off[1])
+            _ = tf.identity(x_, name=_TFOUTPUT)
+            graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
+            self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")
+        # rank 2
+        for axis in [-1, 0, 1, 2]:
+            for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
+                tf.reset_default_graph()
+                x_val = np.arange(0, 50, dtype=np_dtype).reshape([-1, 10])
+                depth = np.array(20).astype(np.int64)
+                x = tf.placeholder(tf_dtype, x_val.shape, name=_TFINPUT)
+                on_off = np.array([5.6, 1.2]).astype(np_dtype)
+                x_ = tf.one_hot(x, depth, on_value=on_off[0], axis=axis, off_value=on_off[1])
+                _ = tf.identity(x_, name=_TFOUTPUT)
+                graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
+                self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")

     @skip_caffe2_backend("issue undefined dim 1")
     def test_flatten0(self):
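
To replay just the one-hot cases locally, an invocation along these lines should work (assuming pytest is installed; the exact runner is a project-setup detail, not something this commit prescribes):

python -m pytest tests/test_backend.py -k "onehot" -v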

tf2onnx/tfonnx.py

Lines changed: 43 additions & 0 deletions
@@ -1287,6 +1287,48 @@ def onehot_op(ctx, node, name, args):
     return


+def onehot_op9(ctx, node, name, args):
+    # T output = OneHot(uint8/int32/int64 input, T depth, T on-value, T off-value, @int axis, @dtype)
+    # tf requires dtype to match the dtype of on-value and off-value
+    # in ONNX, the op's schema is (input, depth, value, @int axis), where "value" is [off-value, on-value]
+    # onnxruntime only supports int64
+    output_dtype = ctx.get_dtype(node.input[2])
+    if ctx.is_target(TARGET_RS6) and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
+        log.warning("unsupported dtype in onnxruntime, onehot-9 can't be used directly")
+        onehot_op(ctx, node, name, args)
+        return
+
+    depth = node.input[1]
+    depth = ctx.make_node("Unsqueeze", [depth], attr={"axes": [0]}).output[0]
+
+    on_value = node.input[2]
+    off_value = node.input[3]
+    on_value = ctx.make_node("Unsqueeze", [on_value], attr={"axes": [0]}).output[0]
+    off_value = ctx.make_node("Unsqueeze", [off_value], attr={"axes": [0]}).output[0]
+    off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]
+
+    indices = node.input[0]
+    if ctx.is_target(TARGET_RS6) and ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64:
+        indices = ctx.make_node("Cast", [indices], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
+    node.input[0] = indices
+
+    if ctx.is_target(TARGET_RS6) and ctx.get_dtype(depth) != onnx_pb.TensorProto.INT64:
+        depth = ctx.make_node("Cast", [depth], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
+    node.input[1] = depth
+
+    if ctx.is_target(TARGET_RS6) and output_dtype != onnx_pb.TensorProto.INT64:
+        off_on_value = ctx.make_node("Cast", [off_on_value], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
+    node.input[2] = off_on_value
+
+    del node.input[3]
+
+    if ctx.is_target(TARGET_RS6) and output_dtype != onnx_pb.TensorProto.INT64:
+        new_node_name = utils.make_name("onehot_output")
+        new_node = ctx.insert_new_node_on_output("Cast", node.output[0], new_node_name, to=output_dtype)
+        ctx.set_dtype(new_node.output[0], output_dtype)
+        ctx.set_shape(new_node.output[0], ctx.get_shape(node.output[0]))
+
+
 def fused_batchnorm_op7(ctx, node, name, args):
     node.type = "BatchNormalization"
     # tf inputs: x, scale, bias, mean, variance
@@ -1866,6 +1908,7 @@ def where_op(ctx, node, name, args):
     "Erf": (direct_op, []),
     "Fill": (fill_op, []),
     "IsNan": (direct_op, ["IsNaN"]),
+    "OneHot": (onehot_op9, []),
     "ResizeBilinear": (upsample_op9, ["Upsample", "linear"]),
     "ResizeNearestNeighbor": (upsample_op9, ["Upsample", "nearest"]),
     "ReverseSequence": (reverse_op9, []),
