Commit 97c47b7

Add constant folding for Split op (#1553)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 080e7cf commit 97c47b7

2 files changed: +98 −0 lines changed

tests/test_optimizers.py

Lines changed: 81 additions & 0 deletions
@@ -1958,6 +1958,87 @@ def test_const_fold_cast_with_const(self):
         self.run_and_compare(["res"], {"X": np.random.randn(*shape).astype(np.int64)}, model_proto,
                              "Cast", 0)

+    def test_const_fold_split(self):
+        shape = (2, 6, 1)
+        const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                          vals=np.random.randn(2, 6, 1).flatten().astype(np.float32))
+        node0 = helper.make_node("Constant", [], ["const"], value=const_tensor)
+        node1 = helper.make_node("Split", ["const"], ["out1", "out2", "out3"], axis=1)
+        node2 = helper.make_node("Sum", ["inp", "out1", "out2", "out3"], ["out4"])
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "test_const_fold_split",
+            [helper.make_tensor_value_info("inp", TensorProto.FLOAT, (2, 2, 1))],
+            [helper.make_tensor_value_info("out4", TensorProto.FLOAT, (2, 2, 1))],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["out4"], {"inp": np.random.randn(2, 2, 1).astype(np.float32)}, model_proto,
+                             "Split", 0)
+
+    def test_const_fold_split_one(self):
+        shape = (2, 6, 1)
+        const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                          vals=np.random.randn(2, 6, 1).flatten().astype(np.float32))
+        node0 = helper.make_node("Constant", [], ["const"], value=const_tensor)
+        node1 = helper.make_node("Split", ["const"], ["out1"], axis=1)
+        node2 = helper.make_node("Sum", ["inp", "out1"], ["out4"])
+
+        graph = helper.make_graph(
+            [node0, node1, node2],
+            "test_const_fold_split",
+            [helper.make_tensor_value_info("inp", TensorProto.FLOAT, (2, 6, 1))],
+            [helper.make_tensor_value_info("out4", TensorProto.FLOAT, (2, 6, 1))],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["out4"], {"inp": np.random.randn(2, 6, 1).astype(np.float32)}, model_proto,
+                             "Split", 0)
+
+    @check_opset_min_version(13, "Split changed in opset 13")
+    def test_const_fold_split_const_splits_13(self):
+        shape = (2, 6, 1)
+        const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                          vals=np.random.randn(2, 6, 1).flatten().astype(np.float32))
+        node0 = helper.make_node("Constant", [], ["const"], value=const_tensor)
+        const_splits = helper.make_tensor(name='const_tensor', data_type=TensorProto.INT64, dims=[3],
+                                          vals=np.array([1, 3, 2], np.int64))
+        node1 = helper.make_node("Constant", [], ["splits"], value=const_splits)
+        node2 = helper.make_node("Split", ["const", "splits"], ["out1", "out2", "out3"], axis=1)
+        node3 = helper.make_node("Sum", ["inp", "out2"], ["out4"])
+
+        graph = helper.make_graph(
+            [node0, node1, node2, node3],
+            "test_const_fold_split",
+            [helper.make_tensor_value_info("inp", TensorProto.FLOAT, (2, 3, 1))],
+            [helper.make_tensor_value_info("out4", TensorProto.FLOAT, (2, 3, 1))],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["out4"], {"inp": np.random.randn(2, 3, 1).astype(np.float32)}, model_proto,
+                             "Split", 0)
+
+    @check_opset_max_version(12, "Split changed in opset 13")
+    def test_const_fold_split_const_splits(self):
+        shape = (2, 6, 1)
+        const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
+                                          vals=np.random.randn(2, 6, 1).flatten().astype(np.float32))
+        node0 = helper.make_node("Constant", [], ["const"], value=const_tensor)
+        node2 = helper.make_node("Split", ["const"], ["out1", "out2", "out3"], axis=1, split=[1, 3, 2])
+        node3 = helper.make_node("Sum", ["inp", "out2"], ["out4"])
+
+        graph = helper.make_graph(
+            [node0, node2, node3],
+            "test_const_fold_split",
+            [helper.make_tensor_value_info("inp", TensorProto.FLOAT, (2, 3, 1))],
+            [helper.make_tensor_value_info("out4", TensorProto.FLOAT, (2, 3, 1))],
+        )
+
+        model_proto = self.make_model(graph, producer_name="onnx-tests")
+        self.run_and_compare(["out4"], {"inp": np.random.randn(2, 3, 1).astype(np.float32)}, model_proto,
+                             "Split", 0)
+
     # Const Fold Optimizer Tests End

     # Const Dequantize Optimizer Tests Start

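Each test ends its run_and_compare call with the arguments "Split", 0, which (following the usual convention of this test harness) asserts that the optimized graph contains zero Split nodes while still producing the same outputs. Conceptually, the first test pins down an equivalence along these lines; this is a hand-written NumPy sketch, not code from the commit:

    import numpy as np

    const = np.random.randn(2, 6, 1).astype(np.float32)
    inp = np.random.randn(2, 2, 1).astype(np.float32)

    # Before folding: Split runs at inference time.
    out1, out2, out3 = np.split(const, 3, axis=1)
    before = inp + out1 + out2 + out3

    # After folding: the three pieces are precomputed once at conversion
    # time and stored as constants, so only the Sum remains in the graph.
    parts = np.split(const, 3, axis=1)
    after = inp + parts[0] + parts[1] + parts[2]

    assert np.allclose(before, after)
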
tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 17 additions & 0 deletions
@@ -6,6 +6,7 @@
 for example, input of transpose node is const then we can do transpose statically instead of at runtime
 """

+import numpy as np
 from .. import utils
 from .optimizer_base import GraphOptimizerBase

@@ -153,3 +154,19 @@ def _fold_unsqueeze(node, graph):

         const_val_after_unsqueeze = const_val.reshape(shape_out)
         return [const_val_after_unsqueeze]
+
+    @staticmethod
+    @_register_func("Split")
+    def _fold_split(node, graph):
+        data = node.inputs[0].get_tensor_value(as_list=False)
+        axis = node.get_attr_value('axis', 0)
+        if len(node.output) == 1:
+            return [data]
+        split = node.get_attr_value('split')
+        if len(node.input) > 1:
+            split = node.inputs[1].get_tensor_value(as_list=False)
+        if split is not None:
+            indices_or_sections = np.cumsum(split[:-1])
+        else:
+            indices_or_sections = len(node.output)
+        return np.split(data, indices_or_sections, axis)

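The main subtlety in _fold_split is a convention mismatch: ONNX Split specifies the sizes of the output sections (a split attribute before opset 13, an optional second input from opset 13 on), while NumPy's np.split takes either a section count or the cut indices along the axis; hence the np.cumsum(split[:-1]). A minimal sketch of both paths, with illustrative shapes and sizes not taken from the commit:

    import numpy as np

    data = np.arange(12).reshape(2, 6, 1)

    # No sizes given: ONNX divides the axis into equal sections, which
    # np.split expresses as an integer section count.
    equal = np.split(data, 3, axis=1)  # three arrays of shape (2, 2, 1)

    # Explicit ONNX sizes [1, 3, 2] become np.split cut indices [1, 4]
    # via the cumulative sum of all but the last size.
    sizes = np.array([1, 3, 2])
    uneven = np.split(data, np.cumsum(sizes[:-1]), axis=1)
    assert [p.shape[1] for p in uneven] == [1, 3, 2]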