Skip to content

Commit f6197d6

Browse files
committed
Fix transpose optimizer to convert simple FusedBatchNormV3 models correctly. The Identity and Transpose handlers require to update both the subgraph as well as parent graph, and also cases where an op in parent graph is consumed in subgraph.
1 parent 4a9d8d1 commit f6197d6

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -406,16 +406,17 @@ def _add_handler(self, trans, node):
406406

407407
def _transpose_handler(self, trans, node):
408408
if is_nchw_transpose(node):
409-
ops = self._g.get_nodes()
410-
self._g.replace_all_inputs(ops, node.output[0], trans.input[0])
411-
412-
shape = self._g.get_shape(node.output[0])
413-
dtype = self._g.get_dtype(node.output[0])
409+
for g in {self._g, node.graph}:
410+
ops = g.get_nodes()
411+
g.replace_all_inputs(ops, node.output[0], trans.input[0])
412+
413+
shape = node.graph.get_shape(node.output[0])
414+
dtype = node.graph.get_dtype(node.output[0])
415+
if node.output[0] in node.graph.outputs:
416+
node.graph.make_node("Identity", [trans.input[0]],
417+
outputs=node.output, shapes=[shape], dtypes=[dtype])
414418
self._g.remove_node(trans.name)
415-
self._g.remove_node(node.name)
416-
if node.output[0] in self._g.outputs:
417-
self._g.make_node("Identity", [trans.input[0]],
418-
outputs=node.output, shapes=[shape], dtypes=[dtype])
419+
node.graph.remove_node(node.name)
419420
return True
420421
return False
421422

@@ -459,11 +460,12 @@ def _mul_handler(self, trans, node):
459460
return False
460461

461462
def _identity_handler(self, trans, node):
462-
if node.output[0] in self._g.outputs:
463+
if node.output[0] in node.graph.outputs:
463464
return False
464-
ops = self._g.get_nodes()
465-
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
466-
self._g.remove_node(node.name)
465+
for g in {self._g, node.graph}:
466+
ops = g.get_nodes()
467+
g.replace_all_inputs(ops, node.output[0], trans.output[0])
468+
node.graph.remove_node(node.name)
467469
return True
468470

469471
def _concat_handler(self, trans, node):

0 commit comments

Comments
 (0)