Skip to content

Commit ef5522d

Browse files
authored
Merge pull request #747 from jignparm/jignparm/fix_fusedbatchnormv3
Fix transpose_optimizer.py to handle FusedBatchNormV3 models
2 parents 4a9d8d1 + f6197d6 commit ef5522d

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -406,16 +406,17 @@ def _add_handler(self, trans, node):
406406

407407
def _transpose_handler(self, trans, node):
408408
if is_nchw_transpose(node):
409-
ops = self._g.get_nodes()
410-
self._g.replace_all_inputs(ops, node.output[0], trans.input[0])
411-
412-
shape = self._g.get_shape(node.output[0])
413-
dtype = self._g.get_dtype(node.output[0])
409+
for g in {self._g, node.graph}:
410+
ops = g.get_nodes()
411+
g.replace_all_inputs(ops, node.output[0], trans.input[0])
412+
413+
shape = node.graph.get_shape(node.output[0])
414+
dtype = node.graph.get_dtype(node.output[0])
415+
if node.output[0] in node.graph.outputs:
416+
node.graph.make_node("Identity", [trans.input[0]],
417+
outputs=node.output, shapes=[shape], dtypes=[dtype])
414418
self._g.remove_node(trans.name)
415-
self._g.remove_node(node.name)
416-
if node.output[0] in self._g.outputs:
417-
self._g.make_node("Identity", [trans.input[0]],
418-
outputs=node.output, shapes=[shape], dtypes=[dtype])
419+
node.graph.remove_node(node.name)
419420
return True
420421
return False
421422

@@ -459,11 +460,12 @@ def _mul_handler(self, trans, node):
459460
return False
460461

461462
def _identity_handler(self, trans, node):
462-
if node.output[0] in self._g.outputs:
463+
if node.output[0] in node.graph.outputs:
463464
return False
464-
ops = self._g.get_nodes()
465-
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
466-
self._g.remove_node(node.name)
465+
for g in {self._g, node.graph}:
466+
ops = g.get_nodes()
467+
g.replace_all_inputs(ops, node.output[0], trans.output[0])
468+
node.graph.remove_node(node.name)
467469
return True
468470

469471
def _concat_handler(self, trans, node):

0 commit comments

Comments
 (0)