Skip to content

Commit 5558957

Browse files
authored
Merge pull request #963 from daquexian/update_outdated_shape
update the out-dated shape in _handle_node_having_branches()
2 parents cb016ef + 168d2a9 commit 5558957

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,12 @@ def _handle_node_having_branches(self, node):
222222
utils.make_sure(len(n.output) == 1, "only expect single output")
223223
self._g.replace_all_inputs(self._g.get_nodes(), n.output[0], n_input)
224224
self._g.remove_node(n.name)
225+
226+
shape = self._g.get_shape(node.output[0])
227+
if shape:
228+
# only nhwc transpose can reach here
229+
new_shape = [shape[i] for i in NHWC_TO_NCHW]
230+
self._g.set_shape(node.output[0], new_shape)
225231
return True
226232

227233
self.logger.debug("input transpose does not have single consumer, skipping...")

0 commit comments

Comments
 (0)