Skip to content

Commit f8a6a86

Browse files
Const fold Concat (#1554)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 97c47b7 commit f8a6a86

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

tests/test_optimizers.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,6 +1898,28 @@ def test_const_fold_node_is_output(self):
18981898
self.run_transpose_compare(["res"], {},
18991899
model_proto, remaining_transpose_num=0)
19001900

1901+
def test_const_fold_concat(self):
1902+
shape = (6, 4)
1903+
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
1904+
vals=np.random.randn(*shape).flatten().astype(np.float32))
1905+
const_tensor2 = helper.make_tensor(name='const_tensor2', data_type=TensorProto.FLOAT, dims=shape,
1906+
vals=np.random.randn(*shape).flatten().astype(np.float32))
1907+
node1 = helper.make_node("Constant", [], ["const"], value=const_tensor)
1908+
node2 = helper.make_node("Constant", [], ["const2"], value=const_tensor2)
1909+
node3 = helper.make_node("Concat", ["const", "const2", "const"], ["value1"], axis=1)
1910+
node4 = helper.make_node("Add", ["value1", "inp"], ["res"])
1911+
1912+
graph = helper.make_graph(
1913+
[node1, node2, node3, node4],
1914+
"test_const_fold_trans_with_const2",
1915+
[helper.make_tensor_value_info("inp", TensorProto.FLOAT, [6, 12])],
1916+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, [6, 12])],
1917+
)
1918+
1919+
model_proto = self.make_model(graph, producer_name="onnx-tests")
1920+
self.run_and_compare(["res"], {"inp": np.random.randn(6, 12).astype(np.float32)}, model_proto,
1921+
"Concat", 0)
1922+
19011923
@check_opset_max_version(12, "Squeeze/Unsqueeze changed in opset 13")
19021924
def test_const_fold_unsqueeze_with_const(self):
19031925
shape = (6, 6)

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,13 @@ def _fold_reshape(node, graph):
128128
const_val_after_trans = const_val_data.reshape(const_val_shape)
129129
return [const_val_after_trans]
130130

131+
@staticmethod
132+
@_register_func("Concat")
133+
def _fold_concat(node, graph):
134+
axis = node.get_attr_value('axis')
135+
res = np.concatenate([inp.get_tensor_value(as_list=False) for inp in node.inputs], axis)
136+
return [res]
137+
131138
@staticmethod
132139
@_register_func("Unsqueeze")
133140
def _fold_unsqueeze(node, graph):

0 commit comments

Comments
 (0)