Skip to content

Commit 53a1ace

Browse files
authored
Merge pull request #55 from onnx/gs/onnx-1.2
support for tf.stack
2 parents 92326e8 + 7f89fe1 commit 53a1ace

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

tests/test_backend.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,22 @@ def test_topk(self):
834834
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
835835
self.assertAllClose(expected, actual)
836836

837+
def test_stack_axis0(self):
838+
x_val = [np.random.randn(3, 4).astype("float32") for _ in range(10)]
839+
x = [tf.constant(x_val[i], dtype=tf.float32) for i in range(10)]
840+
x_ = tf.stack(x, axis=0)
841+
output = tf.identity(x_, name=_TFOUTPUT)
842+
actual, expected = self._run(output, {}, {})
843+
self.assertAllClose(expected, actual)
844+
845+
def test_stack_axis1(self):
846+
x_val = [np.random.randn(3, 4).astype("float32") for _ in range(10)]
847+
x = [tf.constant(x_val[i], dtype=tf.float32) for i in range(10)]
848+
x_ = tf.stack(x, axis=1)
849+
output = tf.identity(x_, name=_TFOUTPUT)
850+
actual, expected = self._run(output, {}, {})
851+
self.assertAllClose(expected, actual)
852+
837853
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "Space2Depth not implemented, works on onnxmsrtnext")
838854
def test_space_to_depth(self):
839855
x_val = make_xval([1, 2, 2, 1])
@@ -873,7 +889,6 @@ def test_strided_slice2(self):
873889

874890
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not correctly supported")
875891
def test_resize_nearest_neighbor(self):
876-
# this should work but no runtime I tried supports it.
877892
x_shape = [1, 15, 20, 2]
878893
x_new_size = [30, 40]
879894
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
@@ -886,7 +901,6 @@ def test_resize_nearest_neighbor(self):
886901

887902
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not correctly supported")
888903
def test_resize_bilinear(self):
889-
# this should work but no runtime I tried supports it.
890904
x_shape = [1, 15, 20, 2]
891905
x_new_size = [30, 40]
892906
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)

tf2onnx/tfonnx.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,27 @@ def minmax_op(ctx, node, name, args):
832832
return node
833833

834834

835+
def pack_op(ctx, node, name, args):
836+
# hack to make up for the missing onnx pack op
837+
axis = node.get_attr("axis").i
838+
nodes = []
839+
inputs = []
840+
# insert Unsqueeze on each input
841+
for i, n in enumerate(node.inputs):
842+
op_name = utils.make_name(node.name)
843+
output_name = op_name + ":0"
844+
new_node = Node(helper.make_node("Unsqueeze", [node.input[i]], [output_name], name=op_name, axes=[axis]), ctx)
845+
node.input[i] = output_name
846+
nodes.append(new_node)
847+
inputs.append(output_name)
848+
# concat all unqueezes
849+
op_name = utils.make_name(node.name)
850+
output_name = op_name + ":0"
851+
concat = Node(helper.make_node("Concat", inputs, [output_name], name=op_name, axis=axis), ctx)
852+
ctx.copy_shape(node.output[0], concat.output[0])
853+
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], output_name)
854+
return [concat] + nodes
855+
835856
# pylint: enable=W0613,C0111,W0612
836857

837858
# map tensorflow ops to onnx ops. The format below is
@@ -917,6 +938,7 @@ def minmax_op(ctx, node, name, args):
917938
"Transpose": (transpose_op, []),
918939
"TopKV2": (topk_op, []),
919940
"SpaceToDepth": (spacetodepth_op, []),
941+
"Pack": (pack_op, []),
920942
}
921943

922944
_OPSET_5 = {

0 commit comments

Comments
 (0)