@@ -406,16 +406,17 @@ def _add_handler(self, trans, node):
406
406
407
407
def _transpose_handler (self , trans , node ):
408
408
if is_nchw_transpose (node ):
409
- ops = self ._g .get_nodes ()
410
- self ._g .replace_all_inputs (ops , node .output [0 ], trans .input [0 ])
411
-
412
- shape = self ._g .get_shape (node .output [0 ])
413
- dtype = self ._g .get_dtype (node .output [0 ])
409
+ for g in {self ._g , node .graph }:
410
+ ops = g .get_nodes ()
411
+ g .replace_all_inputs (ops , node .output [0 ], trans .input [0 ])
412
+
413
+ shape = node .graph .get_shape (node .output [0 ])
414
+ dtype = node .graph .get_dtype (node .output [0 ])
415
+ if node .output [0 ] in node .graph .outputs :
416
+ node .graph .make_node ("Identity" , [trans .input [0 ]],
417
+ outputs = node .output , shapes = [shape ], dtypes = [dtype ])
414
418
self ._g .remove_node (trans .name )
415
- self ._g .remove_node (node .name )
416
- if node .output [0 ] in self ._g .outputs :
417
- self ._g .make_node ("Identity" , [trans .input [0 ]],
418
- outputs = node .output , shapes = [shape ], dtypes = [dtype ])
419
+ node .graph .remove_node (node .name )
419
420
return True
420
421
return False
421
422
@@ -459,11 +460,12 @@ def _mul_handler(self, trans, node):
459
460
return False
460
461
461
462
def _identity_handler (self , trans , node ):
462
- if node .output [0 ] in self . _g .outputs :
463
+ if node .output [0 ] in node . graph .outputs :
463
464
return False
464
- ops = self ._g .get_nodes ()
465
- self ._g .replace_all_inputs (ops , node .output [0 ], trans .output [0 ])
466
- self ._g .remove_node (node .name )
465
+ for g in {self ._g , node .graph }:
466
+ ops = g .get_nodes ()
467
+ g .replace_all_inputs (ops , node .output [0 ], trans .output [0 ])
468
+ node .graph .remove_node (node .name )
467
469
return True
468
470
469
471
def _concat_handler (self , trans , node ):
0 commit comments