@@ -506,6 +506,8 @@ def insert_node(self, node, before=None, input_idx=0):
506506
507507 if next_node is not None :
508508 next_node .inputs [input_idx ] = node .outputs [0 ]
509+ else :
510+ self .outputs = [node .outputs [0 ] if name == prev_node .outputs [0 ] else name for name in self .outputs ]
509511
510512 new_graph = OrderedDict ()
511513 for k , v in self .graph .items ():
@@ -514,47 +516,57 @@ def insert_node(self, node, before=None, input_idx=0):
514516 new_graph [node .name ] = node
515517
516518 self .graph = new_graph
517- self ._update_model_outputs ()
518519
519520 def remove_node (self , node , rewire = True ):
520- """Remove a node from a graph.
521+ """Removes a node from the graph.
521522
522- By default, this function can connect the outputs of previous node to the input of next one.
523- Note that when removing a leaf node `rewire` should be set to `False`.
523+ By default, this function connects the outputs of the previous
524+ node to the inputs of the next node. If the removed node has multiple
525+ input/output tensors, an exception is raised.
524526
525527 Args:
526- node (Layer): The node to remove
527- rewire (bool, optional): If `True`, connects the outputs of the previous node
528- to the inputs of the next node
528+ node (Layer): The node to remove.
529+ rewire (bool, optional): Deprecated, has no effect.
529530
530531 Raises:
531- Exception: If an attempt is made to rewire a leaf node or a node with multiple
532- inputs/outputs.
532+ Exception: If an attempt is made to rewire a node with
533+ multiple inputs/outputs.
533534
535+ Note:
536+ The `rewire` parameter is deprecated and has no effect.
534537 """
535- if rewire :
536- inputs = [inp for inp in node .inputs if inp ]
537- outputs = [outp for outp in node .outputs if outp ]
538- if len (inputs ) > 1 or len (outputs ) > 1 :
539- raise Exception ('Cannot rewire a node with multiple inputs/outputs' )
540- prev_node = node .get_input_node (node .inputs [0 ])
538+
539+ inputs = [inp for inp in node .inputs if inp ]
540+ outputs = [outp for outp in node .outputs if outp ]
541+
542+ if len (inputs ) > 1 or len (outputs ) > 1 :
543+ raise Exception ('Cannot delete a node with multiple inputs/outputs' )
544+
545+ if len (inputs ) == 1 :
546+ # Connect inputs -> $outputs
547+ if node .name in self .outputs :
548+ msg = f'Remove leaf node { node .name } will connect its input node { inputs [0 ]} to output, but it already is.'
549+ assert inputs [0 ] not in self .outputs , msg
550+ self .outputs = [inputs [0 ] if name == node .name else name for name in self .outputs ]
551+
552+ if len (outputs ) == 1 and len (inputs ) == 1 :
553+ inp_var = node .get_input_variable ()
554+ out_var = node .get_output_variable ()
555+
556+ # fmt: off
557+ assert (np .prod (inp_var .shape ) == np .prod (out_var .shape )), \
558+ f'Input and output shapes do not match for { node .name } : { inp_var .shape } -> { out_var .shape } '
559+ # fmt: on
560+
541561 next_nodes = [x for x in self .graph .values () if node .outputs [0 ] in x .inputs ]
542- if prev_node is not None :
543- if len (next_nodes ) > 0 :
544- for next_node in next_nodes :
545- for i , _ in enumerate (next_node .inputs ):
546- if node .outputs [0 ] == next_node .inputs [i ]:
547- next_node .inputs [i ] = prev_node .outputs [0 ]
548- break
549- else :
550- if not node .outputs [0 ] in self .outputs :
551- raise Exception ('Cannot rewire a node without child' )
552- else :
553- raise Exception ('Cannot rewire a node without a parent' )
562+ for next_node in next_nodes :
563+ # Connect inputs -> next
564+ for i , nxt_inp in enumerate (next_node .inputs ):
565+ if outputs [0 ] == nxt_inp :
566+ next_node .inputs [i ] = inputs [0 ]
554567
555568 del self .output_vars [node .outputs [0 ]]
556569 del self .graph [node .name ]
557- self ._update_model_outputs ()
558570
559571 def replace_node (self , old_node , new_node ):
560572 """Replace an existing node in the graph with a new one.
@@ -584,7 +596,11 @@ def replace_node(self, old_node, new_node):
584596 node .outputs [i ] = repl [n ]
585597
586598 self .graph = OrderedDict ((new_node .name , new_node ) if k == old_node .name else (k , v ) for k , v in self .graph .items ())
587- self ._update_model_outputs ()
599+
600+ old_name = old_node .name
601+ if old_name in self .outputs :
602+ new_name = new_node .name
603+ self .outputs = [new_name if name == old_name else name for name in self .outputs ]
588604
589605 def split_node (self , old_node , new_node1 , new_node2 ):
590606 """Replace an existing node in the graph with two nodes in sequence.
@@ -622,17 +638,9 @@ def split_node(self, old_node, new_node1, new_node2):
622638 else :
623639 new_graph [key ] = value
624640 self .graph = new_graph
625- self ._update_model_outputs ()
626-
627- def _update_model_outputs (self ):
628- '''Update the model outputs
629641
630- All node outputs and inputs are found. The model outputs are set to all node outputs
631- that are not also node inputs.
632- '''
633- node_outputs = [out for node in self .graph .values () for out in node .outputs ]
634- node_inputs = [inp for node in self .graph .values () for inp in node .inputs ]
635- self .outputs = [out for out in node_outputs if out not in node_inputs ]
642+ if old_node .name in self .outputs :
643+ self .outputs = [new_node2 .name if name == old_node .name else name for name in self .outputs ]
636644
637645 def next_layer (self ):
638646 self .index += 1
0 commit comments