Commit 931117d

refactor

1 parent 6f7c8b5 commit 931117d

File tree: 2 files changed (+30, -24 lines)

tests/test_backend.py

Lines changed: 20 additions & 17 deletions

@@ -1128,7 +1128,7 @@ def test_onehot0(self):
         _ = tf.identity(x_, name=_TFOUTPUT)
         self._run_test_case([_OUTPUT], {_INPUT: x_val})

-    @unittest.skip("")
+    @unittest.skip("only rank 1 is currently implemented")
     def test_onehot1(self):
         # only rank 1 is currently implemented
         x_val = np.array([[0, 2], [1, -1]], dtype=np.int32)
@@ -1139,12 +1139,14 @@ def test_onehot1(self):
         self._run_test_case([_OUTPUT], {_INPUT: x_val})

     def test_onehot2(self):
-        x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
-        depth = 20
-        x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
-        x_ = tf.one_hot(x, depth, on_value=5.0, axis=-1, off_value=1.0, dtype=tf.float32)
-        _ = tf.identity(x_, name=_TFOUTPUT)
-        self._run_test_case([_OUTPUT], {_INPUT: x_val})
+        for axis in [-1, 0, 1]:
+            tf.reset_default_graph()
+            x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
+            depth = 20
+            x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
+            x_ = tf.one_hot(x, depth, on_value=5.0, axis=axis, off_value=1.0, dtype=tf.float32)
+            _ = tf.identity(x_, name=_TFOUTPUT)
+            self._run_test_case([_OUTPUT], {_INPUT: x_val})

     @check_opset_min_version(9, "onehot")
     def test_onehot3(self):
@@ -1160,16 +1162,17 @@ def test_onehot3(self):
         graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
         self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")
         # rank 2
-        for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
-            tf.reset_default_graph()
-            x_val = np.arange(0, 50, dtype=np_dtype).reshape([-1, 10])
-            depth = np.array(20).astype(np.int64)
-            x = tf.placeholder(tf_dtype, x_val.shape, name=_TFINPUT)
-            on_off = np.array([5.6, 1.2]).astype(np_dtype)
-            x_ = tf.one_hot(x, depth, on_value=on_off[0], axis=-1, off_value=on_off[1])
-            _ = tf.identity(x_, name=_TFOUTPUT)
-            graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
-            self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")
+        for axis in [-1, 0, 1, 2]:
+            for np_dtype, tf_dtype in zip([np.int32, np.int64], [tf.int32, tf.int64]):
+                tf.reset_default_graph()
+                x_val = np.arange(0, 50, dtype=np_dtype).reshape([-1, 10])
+                depth = np.array(20).astype(np.int64)
+                x = tf.placeholder(tf_dtype, x_val.shape, name=_TFINPUT)
+                on_off = np.array([5.6, 1.2]).astype(np_dtype)
+                x_ = tf.one_hot(x, depth, on_value=on_off[0], axis=axis, off_value=on_off[1])
+                _ = tf.identity(x_, name=_TFOUTPUT)
+                graph = self._run_test_case([_OUTPUT], {_INPUT: x_val})
+                self.assertTrue(len(group_nodes_by_type(graph)["OneHot"]) == 1, "onnx onehot should be used")

     @skip_caffe2_backend("issue undefined dim 1")
     def test_flatten0(self):
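
For orientation, here is a minimal numpy sketch (not part of the commit; the one_hot helper is illustrative) of the tf.one_hot axis semantics the widened test loops now exercise: the depth axis is inserted at position axis, with -1 meaning innermost.

import numpy as np

def one_hot(indices, depth, on_value, off_value, axis=-1):
    # Build the encoding with the depth axis last, then move it into place.
    # As in tf.one_hot, out-of-range (e.g. negative) indices stay "off".
    out = np.full(indices.shape + (depth,), off_value, dtype=np.float32)
    valid = (indices >= 0) & (indices < depth)
    out[np.nonzero(valid) + (indices[valid],)] = on_value
    return np.moveaxis(out, -1, axis)

x = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
print(one_hot(x, 20, 5.0, 1.0, axis=-1).shape)  # (10, 20): depth appended
print(one_hot(x, 20, 5.0, 1.0, axis=0).shape)   # (20, 10): depth prepended

For rank-1 indices, axis=1 coincides with axis=-1, which is why the rank-1 loop stops at 1 while the rank-2 loop in test_onehot3 goes up to 2.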

tf2onnx/tfonnx.py

Lines changed: 10 additions & 7 deletions

@@ -1294,8 +1294,9 @@ def onehot_op9(ctx, node, name, args):
     # onnxruntime only supports int64
     output_dtype = ctx.get_dtype(node.input[2])
     if ctx.is_target(TARGET_RS6) and output_dtype not in [onnx_pb.TensorProto.INT64, onnx_pb.TensorProto.INT32]:
-        log.warning("unsupported dtype in onnxruntime, onehot can't be used directly")
+        log.warning("unsupported dtype in onnxruntime, onehot-9 can't be used directly")
         onehot_op(ctx, node, name, args)
+        return

     depth = node.input[1]
     depth = ctx.make_node("Unsqueeze", [depth], attr={"axes": [0]}).output[0]
@@ -1307,23 +1308,25 @@ def onehot_op9(ctx, node, name, args):
     off_on_value = ctx.make_node("Concat", [off_value, on_value], attr={"axis": 0}).output[0]

     indices = node.input[0]
-    if ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64 and ctx.is_target(TARGET_RS6):
+    if ctx.is_target(TARGET_RS6) and ctx.get_dtype(indices) != onnx_pb.TensorProto.INT64:
         indices = ctx.make_node("Cast", [indices], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
     node.input[0] = indices

-    if ctx.get_dtype(depth) != onnx_pb.TensorProto.INT64 and ctx.is_target(TARGET_RS6):
+    if ctx.is_target(TARGET_RS6) and ctx.get_dtype(depth) != onnx_pb.TensorProto.INT64:
         depth = ctx.make_node("Cast", [depth], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
     node.input[1] = depth

-    if output_dtype != onnx_pb.TensorProto.INT64 and ctx.is_target(TARGET_RS6):
+    if ctx.is_target(TARGET_RS6) and output_dtype != onnx_pb.TensorProto.INT64:
         off_on_value = ctx.make_node("Cast", [off_on_value], attr={"to": onnx_pb.TensorProto.INT64}).output[0]
     node.input[2] = off_on_value

     del node.input[3]

-    if output_dtype != onnx_pb.TensorProto.INT64 and ctx.is_target(TARGET_RS6):
-        new_output = utils.make_name("onehot_output")
-        ctx.insert_new_node_on_output("Cast", node.output[0], new_output, to=output_dtype)
+    if ctx.is_target(TARGET_RS6) and output_dtype != onnx_pb.TensorProto.INT64:
+        new_node_name = utils.make_name("onehot_output")
+        new_node = ctx.insert_new_node_on_output("Cast", node.output[0], new_node_name, to=output_dtype)
+        ctx.set_dtype(new_node.output[0], output_dtype)
+        ctx.set_shape(new_node.output[0], ctx.get_shape(node.output[0]))


 def fused_batchnorm_op7(ctx, node, name, args):
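
The added return ensures the opset-7 fallback does not fall through into the OneHot-9 rewrite. For the non-fallback path, here is a sketch (not from the repo; all tensor names are illustrative) of the node pair the TARGET_RS6 branch ends up emitting: a OneHot-9 fed int64 indices, the unsqueezed depth, and the concatenated [off_value, on_value] pair, followed by the Cast back to the requested dtype whose dtype and shape the commit now records.

from onnx import TensorProto, helper

# ONNX OneHot-9 packs both fill values into one 2-element "values" input,
# in [off_value, on_value] order -- hence the Concat of off before on above.
onehot = helper.make_node(
    "OneHot",
    inputs=["indices_int64", "depth_unsqueezed", "off_on_value"],
    outputs=["onehot_raw"],
    axis=-1,
)
# When the requested output dtype is not int64, a Cast is appended on the
# output; the converter now also records the Cast's dtype/shape so later
# shape inference over the graph stays consistent.
cast_back = helper.make_node(
    "Cast",
    inputs=["onehot_raw"],
    outputs=["onehot_output"],
    to=TensorProto.FLOAT,
)
print(onehot)
print(cast_back)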
