Skip to content

Commit 076b436

Browse files
committed
detects obious failure in ReshapeOptimizer and bypass it
Signed-off-by: xavier dupré <[email protected]>
1 parent 720cfb7 commit 076b436

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tf2onnx/optimizer/reshape_optimizer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,13 @@ def get_reshape_dim(val, i, shift):
100100
new_reshape_shape = [1] * shift + graph.get_shape(node.output[0])
101101
graph.insert_node_on_output(squeeze_node, node.output[0])
102102
const_shape = graph.make_const(utils.make_name(node.name + "_shape"), np.array(new_shape, np.int64)).output[0]
103+
if inp_shape == [-1] and len(new_shape) > 1:
104+
# This is a mismatch.
105+
return False
103106
if new_reshape_shape is not None:
104107
graph.set_shape(node.output[0], new_reshape_shape)
105108
graph.replace_inputs(node, [node.input[0], const_shape])
106109
if shift < 0:
107110
unsqueeze_node = GraphBuilder(graph).make_unsqueeze({'data': node.input[0], 'axes': list(range(-shift))})
108111
graph.replace_inputs(node, [unsqueeze_node, const_shape])
109-
110112
return True

0 commit comments

Comments
 (0)