Skip to content

Commit 8e025e3

Browse files
committed
support for onehost, support for reshape on int tensors
1 parent d9d8054 commit 8e025e3

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

tests/test_backend.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ def test_onehot0(self):
805805
x_val = np.array([0, 1, 2], dtype=np.int32)
806806
depth = 3
807807
x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
808-
x_ = tf.one_hot(x, depth, on_value=5.0, axis=1, off_value=1.0, dtype=tf.float32)
808+
x_ = tf.one_hot(x, depth, on_value=5.0, axis=0, off_value=1.0, dtype=tf.float32)
809809
output = tf.identity(x_, name=_TFOUTPUT)
810810
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
811811
self.assertAllClose(expected, actual)
@@ -821,6 +821,16 @@ def test_onehot1(self):
821821
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
822822
self.assertAllClose(expected, actual)
823823

824+
def test_onehot2(self):
825+
# no such op in onnx
826+
x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
827+
depth = 20
828+
x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
829+
x_ = tf.one_hot(x, depth, on_value=5.0, axis=-1, off_value=1.0, dtype=tf.float32)
830+
output = tf.identity(x_, name=_TFOUTPUT)
831+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
832+
self.assertAllClose(expected, actual)
833+
824834
@unittest.skipIf(BACKEND in ["caffe2"], "issue undefined dim 1")
825835
def test_flatten0(self):
826836
x_val = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], dtype=np.float32)

tf2onnx/tfonnx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,12 +892,14 @@ def onehot_op(ctx, node, name, args):
892892
# until there is no onehot op in onnx, a workaround using gather from eye
893893
data = node.input[0]
894894
shape = ctx.get_shape(data)
895+
shapeo = ctx.get_shape(node.output[0])
895896
if len(shape) != 1:
896897
# TODO: this works for rank=1 but tensorflow supports more than this.
897898
# Same principle should work but we need to implemtn our own eye.
898899
raise ValueError("onehot op: only rank1 is supported")
899900
axis = node.get_attr("axis")
900901
node.set_attr("axis", axis.i)
902+
node.set_attr("axis", 0)
901903
depth = node.inputs[1].get_tensor_value()[0]
902904
on = node.inputs[2].get_tensor_value()[0]
903905
off = node.inputs[3].get_tensor_value()[0]

0 commit comments

Comments
 (0)