Skip to content

Commit 9875243

Browse files
committed
map tf.add_n to sum
1 parent 5c83131 commit 9875243

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

tests/test_backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,15 @@ def test_topk(self):
770770
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
771771
self.assertAllClose(expected, actual)
772772

773+
@unittest.skipIf(OPSET < 6, "supported since opset 6")
774+
def test_addn(self):
775+
x_val = np.arange(3*2*3).astype("float32")
776+
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
777+
x_ = tf.add_n([x, x, x])
778+
output = tf.identity(x_, name=_TFOUTPUT)
779+
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
780+
self.assertAllClose(expected, actual)
781+
773782
@unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "multiple dims not supported")
774783
def test_strided_slice1(self):
775784
x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)

tf2onnx/tfonnx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,6 +876,10 @@ def topk_op(ctx, node, name, args):
876876
"ExpandDims": (expanddims_op7, []),
877877
}
878878

879+
_OPSET_6 = {
880+
"AddN": (direct_op, ["Sum"]),
881+
}
882+
879883
_OPSET_7 = {
880884
"Tile": (direct_op, []),
881885
"ResizeNearestNeighbor": (upsample_op, []),
@@ -903,6 +907,7 @@ def topk_op(ctx, node, name, args):
903907
_OPSETS = [
904908
(4, _OPSET_4),
905909
(5, _OPSET_5),
910+
(6, _OPSET_6),
906911
(7, _OPSET_7),
907912
]
908913

0 commit comments

Comments
 (0)