Skip to content

Commit 6ab3f7c

Browse files
committed
enhance const_fold_optimizer: supports to fold const with Unsqueeze
1 parent a27312a commit 6ab3f7c

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

tests/test_optimizers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,25 @@ def test_const_fold_node_is_output(self):
484484
model_proto = helper.make_model(graph, producer_name="onnx-tests")
485485
self.run_transpose_compare(["res"], {},
486486
model_proto, remaining_transpose_num=0)
487+
488+
def test_const_fold_unsqueeze_with_const(self):
489+
shape = (6, 6)
490+
const_tensor = helper.make_tensor(name='const_tensor', data_type=TensorProto.FLOAT, dims=shape,
491+
vals=np.random.randn(*shape).flatten().astype(np.float32))
492+
node1 = helper.make_node("Constant", [], ["const"], value=const_tensor)
493+
node2 = helper.make_node("Unsqueeze", ["const"], ["value1"], axes=[0, 2, 3])
494+
node3 = helper.make_node("Add", ["value1", "X"], ["res"])
495+
496+
graph = helper.make_graph(
497+
[node1, node2, node3],
498+
"test_const_fold_unsqueeze_with_const",
499+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1,))],
500+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 6, 1, 1, 6))],
501+
)
502+
503+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
504+
self.run_and_compare(["res"], {"X": np.random.randn(1).astype(np.float32)}, model_proto,
505+
"Unsqueeze", 0)
487506
# Const Fold Optimizer Tests End
488507

489508

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,27 @@ def _fold_transpose(node, graph) -> list:
100100
perm = perm_attr.ints if perm_attr else None
101101
const_val_after_trans = const_val.transpose(perm)
102102
return [const_val_after_trans]
103+
104+
@staticmethod
105+
@_register_func("Unsqueeze")
106+
def _fold_unsqueeze(node, graph):
107+
"""
108+
numpy expand_dims only supports to unsqueeze one dim one time, so reshape is used to simplify the logic
109+
"""
110+
const_val = node.inputs[0].get_tensor_value(as_list=False)
111+
axes = list(node.get_attr("axes").ints)
112+
utils.make_sure(all(axis >= 0 for axis in axes), "onnx spec says it only supports positive axis")
113+
shape_in = const_val.shape
114+
dims_out = len(shape_in) + len(axes)
115+
# calculate the shape of output accroding to onnx Unsqueeze's spec
116+
# https://github.com/onnx/onnx/blob/master/docs/Operators.md#Unsqueeze
117+
shape_in = iter(shape_in)
118+
shape_out = [None] * dims_out
119+
for ind in axes:
120+
shape_out[ind] = 1
121+
for ind, val in enumerate(shape_out):
122+
if val is None:
123+
shape_out[ind] = next(shape_in)
124+
125+
const_val_after_unsqueeze = const_val.reshape(shape_out)
126+
return [const_val_after_unsqueeze]

0 commit comments

Comments
 (0)