Skip to content

Commit 5338003

Browse files
committed
support for onehot (rank=1 for now)
1 parent 8e025e3 commit 5338003

File tree

3 files changed

+28
-18
lines changed

3 files changed

+28
-18
lines changed

tests/test_backend.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -803,12 +803,14 @@ def test_cast(self):
803803
def test_onehot0(self):
804804
# no such op in onnx
805805
x_val = np.array([0, 1, 2], dtype=np.int32)
806-
depth = 3
807-
x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
808-
x_ = tf.one_hot(x, depth, on_value=5.0, axis=0, off_value=1.0, dtype=tf.float32)
809-
output = tf.identity(x_, name=_TFOUTPUT)
810-
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
811-
self.assertAllClose(expected, actual)
806+
depth = 5
807+
for axis in [-1, 0, 1]:
808+
tf.reset_default_graph()
809+
x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
810+
x_ = tf.one_hot(x, depth, on_value=5.0, axis=axis, off_value=1.0, dtype=tf.float32)
811+
output = tf.identity(x_, name=_TFOUTPUT)
812+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
813+
self.assertAllClose(expected, actual)
812814

813815
@unittest.skip
814816
def test_onehot1(self):

tests/unity.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,11 @@ BananaRL:
4242
- value_estimate:0
4343

4444
Basic:
45-
# needs: onehot
46-
disabled: true
4745
url: https://github.com/Unity-Technologies/ml-agents/raw/master/unity-environment/Assets/ML-Agents/Examples/Basic/TFModels/Basic.bytes
4846
model: Basic.bytes
4947
input_get: get_random
5048
inputs:
51-
"vector_observation:0": [1, 1]
49+
"vector_observation:0": [10, 1]
5250
outputs:
5351
- action:0
5452
- action_probs:0

tf2onnx/tfonnx.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -890,29 +890,39 @@ def unpack_op(ctx, node, name, args):
890890

891891
def onehot_op(ctx, node, name, args):
892892
# until there is no onehot op in onnx, a workaround using gather from eye
893-
data = node.input[0]
894-
shape = ctx.get_shape(data)
895-
shapeo = ctx.get_shape(node.output[0])
896-
if len(shape) != 1:
893+
indices_name = node.input[0]
894+
indices_shape = ctx.get_shape(indices_name)
895+
if len(indices_shape) != 1:
897896
# TODO: this works for rank=1 but tensorflow supports more than this.
898897
# Same principle should work but we need to implemtn our own eye.
899898
raise ValueError("onehot op: only rank1 is supported")
900899
axis = node.get_attr("axis")
901-
node.set_attr("axis", axis.i)
900+
# axis becomes axis for gather
902901
node.set_attr("axis", 0)
903902
depth = node.inputs[1].get_tensor_value()[0]
904903
on = node.inputs[2].get_tensor_value()[0]
905904
off = node.inputs[3].get_tensor_value()[0]
906905
dtype = node.inputs[2].get_tensor_type()
907-
del node.input[:]
908-
eye = np.eye(depth, dtype=dtype) * on
909-
if off != 0:
906+
eye = np.eye(depth, dtype=dtype)
907+
if on != 0:
908+
eye[eye == 1] = on
910909
eye[eye == 0] = off
910+
else:
911+
eye[eye == 0] = off
912+
eye[eye == 1] = on
911913
const_name = utils.make_name(node.name)
912914
ctx.make_const(const_name, "Const", eye)
915+
# setup gather inputs
916+
del node.input[:]
913917
node.input.append(const_name)
914-
node.input.append(data)
918+
node.input.append(indices_name)
915919
node.type = "Gather"
920+
if axis.i == 0:
921+
# TODO: revisit for rank > 1
922+
name = utils.make_name(node.name)
923+
transpose_op = ctx.insert_new_node_on_output("Transpose", node.output[0], name)
924+
ctx.copy_shape(node.output[0], transpose_op.output[0])
925+
return [node, transpose_op]
916926
return node
917927

918928

0 commit comments

Comments
 (0)