@@ -755,6 +755,8 @@ def onnx_type_str_to_enum(dtype: str) -> int:
755755def remove_node_training_mode (onnx_model : onnx .ModelProto , node_op_type : str ) -> onnx .ModelProto :
756756 """Remove `training_mode` attribute and extra training outputs from nodes of a given op type.
757757
758+ This also removes the unused outputs from the training_mode nodes.
759+
758760 Args:
759761 onnx_model: The onnx model.
760762 node_op_type: The node type to remove training_mode attribute from.
@@ -763,33 +765,38 @@ def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) ->
763765 The onnx model with the training_mode attribute removed.
764766 """
765767 removed_output_names = set ()
768+ all_inputs = {inp for n in onnx_model .graph .node for inp in n .input }
769+ graph_outputs = {o .name for o in onnx_model .graph .output }
770+ keep = all_inputs | graph_outputs
766771
767772 for node in onnx_model .graph .node :
768773 if node .op_type != node_op_type :
769774 continue
770775
776+ is_training_mode = False
771777 # Drop the 'training_mode' attribute if present
772778 for idx , attr in enumerate (list (node .attribute )):
773779 if attr .name == "training_mode" :
774780 del node .attribute [idx ]
781+ if attr .i == 1 :
782+ is_training_mode = True
775783 break
776784
777- # If node has extra training outputs, keep only the first
778- if len (node .output ) > 1 :
779- removed_output_names .update (node .output [1 :])
780- node .output [:] = node .output [:1 ]
785+ # If the node has extra outputs, remove them all including the training outputs
786+ if is_training_mode :
787+ to_remove = []
788+ for name in node .output :
789+ if name not in keep :
790+ removed_output_names .add (name )
791+ to_remove .append (name )
792+
793+ for name in to_remove :
794+ node .output .remove (name )
781795
782796 if removed_output_names :
783797 # Clean up corresponding value_info entries
784798 keep = [vi for vi in onnx_model .graph .value_info if vi .name not in removed_output_names ]
785799 del onnx_model .graph .value_info [:]
786800 onnx_model .graph .value_info .extend (keep )
787801
788- # Also clean up graph.output entries
789- keep_outputs = [
790- out for out in onnx_model .graph .output if out .name not in removed_output_names
791- ]
792- del onnx_model .graph .output [:]
793- onnx_model .graph .output .extend (keep_outputs )
794-
795802 return onnx_model
0 commit comments