Skip to content

Commit a4c8bbe

Browse files
committed
Adding Constant Folding for Reshape Nodes to Optimizer
1 parent b4126fb commit a4c8bbe

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

tf2onnx/optimizer/const_fold_optimizer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ def _fold_transpose(node, graph) -> list:
110110
const_val_after_trans = const_val.transpose(perm)
111111
return [const_val_after_trans]
112112

113+
@staticmethod
114+
@_register_func("Reshape")
115+
def _fold_reshape(node, graph):
116+
const_val_data = node.inputs[0].get_tensor_value(as_list=False)
117+
const_val_shape = node.inputs[1].get_tensor_value(as_list=False)
118+
const_val_after_trans = const_val_data.reshape(const_val_shape)
119+
return [const_val_after_trans]
120+
113121
@staticmethod
114122
@_register_func("Unsqueeze")
115123
def _fold_unsqueeze(node, graph):

0 commit comments

Comments
 (0)