Skip to content

Commit c118261

Browse files
authored
Merge pull request #56 from onnx/gs/onnx-1.2
support for tf.unstack
2 parents 53a1ace + 88ada9d commit c118261

File tree

2 files changed

+42
-12
lines changed

2 files changed

+42
-12
lines changed

tests/test_backend.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -834,18 +834,30 @@ def test_topk(self):
834834
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
835835
self.assertAllClose(expected, actual)
836836

837-
def test_stack_axis0(self):
838-
x_val = [np.random.randn(3, 4).astype("float32") for _ in range(10)]
839-
x = [tf.constant(x_val[i], dtype=tf.float32) for i in range(10)]
840-
x_ = tf.stack(x, axis=0)
841-
output = tf.identity(x_, name=_TFOUTPUT)
842-
actual, expected = self._run(output, {}, {})
843-
self.assertAllClose(expected, actual)
844-
845-
def test_stack_axis1(self):
846-
x_val = [np.random.randn(3, 4).astype("float32") for _ in range(10)]
847-
x = [tf.constant(x_val[i], dtype=tf.float32) for i in range(10)]
848-
x_ = tf.stack(x, axis=1)
837+
def test_stack_axis(self):
838+
for axis in [0, 1]:
839+
tf.reset_default_graph()
840+
x_val = [np.random.randn(3, 4).astype("float32") for _ in range(10)]
841+
x = [tf.constant(x_val[i], dtype=tf.float32) for i in range(10)]
842+
x_ = tf.stack(x, axis=axis)
843+
output = tf.identity(x_, name=_TFOUTPUT)
844+
actual, expected = self._run(output, {}, {})
845+
self.assertAllClose(expected, actual)
846+
847+
def test_unstack_axis(self):
848+
for axis in [0, 1]:
849+
tf.reset_default_graph()
850+
x_val = np.random.randn(10, 3, 4).astype("float32")
851+
x = tf.constant(x_val, dtype=tf.float32)
852+
x_ = tf.unstack(x, axis=axis)
853+
output = tf.identity(x_, name=_TFOUTPUT)
854+
actual, expected = self._run(output, {}, {})
855+
self.assertAllClose(expected, actual)
856+
857+
def test_unstack_axis1(self):
858+
x_val = np.random.randn(10, 3, 4).astype("float32")
859+
x = tf.constant(x_val, dtype=tf.float32)
860+
x_ = tf.unstack(x, axis=1)
849861
output = tf.identity(x_, name=_TFOUTPUT)
850862
actual, expected = self._run(output, {}, {})
851863
self.assertAllClose(expected, actual)

tf2onnx/tfonnx.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,23 @@ def pack_op(ctx, node, name, args):
853853
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], output_name)
854854
return [concat] + nodes
855855

856+
def unpack_op(ctx, node, name, args):
857+
# hack to make up for the missing onnx unpack op
858+
axis = node.get_attr("axis").i
859+
# split the tensor into n outputs
860+
node.type = "Split"
861+
nodes = [node]
862+
# for each output we need to squeeze axis
863+
for i, n in enumerate(node.output):
864+
op_name = utils.make_name(node.name)
865+
output_name = op_name + ":" + str(i)
866+
new_node = Node(helper.make_node("Squeeze", [n], [output_name], name=op_name, axes=[axis]), ctx)
867+
nodes.append(new_node)
868+
ctx.copy_shape(n, output_name)
869+
ctx.replace_all_inputs(ctx.get_nodes(), n, output_name)
870+
return nodes
871+
872+
856873
# pylint: enable=W0613,C0111,W0612
857874

858875
# map tensorflow ops to onnx ops. The format below is
@@ -939,6 +956,7 @@ def pack_op(ctx, node, name, args):
939956
"TopKV2": (topk_op, []),
940957
"SpaceToDepth": (spacetodepth_op, []),
941958
"Pack": (pack_op, []),
959+
"Unpack": (unpack_op, []),
942960
}
943961

944962
_OPSET_5 = {

0 commit comments

Comments
 (0)