Skip to content

Commit 560cfc1

Browse files
committed
enhance tranpose opt with squeeze
1 parent 4d0fced commit 560cfc1

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -500,23 +500,27 @@ def _calculate_new_attr(ori_perm, ori_squeeze_axes):
500500
return False
501501

502502
if node.get_attr("axes"):
503-
squeeze_axes = sorted(list(node.get_attr("axes").ints))
504-
trans_perm = list(trans.get_attr("perm").ints)
505-
squeeze_shape = self._g.get_shape(node.output[0])
506503
# switch tran and squeeze
507504
# 1 switch
508505
ops = self._g.get_nodes()
509506
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
510507
node.input[0] = trans.input[0]
511508
trans.input[0] = node.output[0]
512509
# 2 correct attr of nodes
510+
squeeze_axes = sorted(list(node.get_attr("axes").ints))
511+
trans_perm = list(trans.get_attr("perm").ints)
513512
new_perm, new_squeeze_axes = _calculate_new_attr(ori_perm=trans_perm, ori_squeeze_axes=squeeze_axes)
514513
trans.set_attr("perm", new_perm)
515514
node.set_attr("axes", new_squeeze_axes)
516515
# 3 set shape
516+
squeeze_shape = self._g.get_shape(node.output[0])
517517
self._g.set_shape(trans.output[0], squeeze_shape)
518518
input_shape = self._g.get_shape(node.input[0])
519-
new_squeeze_output_shape = [input_shape[i] for i in range(4) if i not in new_squeeze_axes]
519+
if input_shape is not None:
520+
new_squeeze_output_shape = [input_shape[i] for i in range(4) if i not in new_squeeze_axes]
521+
else:
522+
new_squeeze_output_shape = [-1]*4
523+
self.logger.warning("%s'sshape is unknown, which may interfere further optimization", node.input[0])
520524
self._g.set_shape(node.output[0], new_squeeze_output_shape)
521525
return True
522526
return False

0 commit comments

Comments
 (0)