@@ -755,6 +755,8 @@ def onnx_type_str_to_enum(dtype: str) -> int:
755
755
def remove_node_training_mode (onnx_model : onnx .ModelProto , node_op_type : str ) -> onnx .ModelProto :
756
756
"""Remove `training_mode` attribute and extra training outputs from nodes of a given op type.
757
757
758
+ This also removes the unused outputs from the training_mode nodes.
759
+
758
760
Args:
759
761
onnx_model: The onnx model.
760
762
node_op_type: The node type to remove training_mode attribute from.
@@ -763,33 +765,38 @@ def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) ->
763
765
The onnx model with the training_mode attribute removed.
764
766
"""
765
767
removed_output_names = set ()
768
+ all_inputs = {inp for n in onnx_model .graph .node for inp in n .input }
769
+ graph_outputs = {o .name for o in onnx_model .graph .output }
770
+ keep = all_inputs | graph_outputs
766
771
767
772
for node in onnx_model .graph .node :
768
773
if node .op_type != node_op_type :
769
774
continue
770
775
776
+ is_training_mode = False
771
777
# Drop the 'training_mode' attribute if present
772
778
for idx , attr in enumerate (list (node .attribute )):
773
779
if attr .name == "training_mode" :
774
780
del node .attribute [idx ]
781
+ if attr .i == 1 :
782
+ is_training_mode = True
775
783
break
776
784
777
- # If node has extra training outputs, keep only the first
778
- if len (node .output ) > 1 :
779
- removed_output_names .update (node .output [1 :])
780
- node .output [:] = node .output [:1 ]
785
+ # If the node has extra outputs, remove them all including the training outputs
786
+ if is_training_mode :
787
+ to_remove = []
788
+ for name in node .output :
789
+ if name not in keep :
790
+ removed_output_names .add (name )
791
+ to_remove .append (name )
792
+
793
+ for name in to_remove :
794
+ node .output .remove (name )
781
795
782
796
if removed_output_names :
783
797
# Clean up corresponding value_info entries
784
798
keep = [vi for vi in onnx_model .graph .value_info if vi .name not in removed_output_names ]
785
799
del onnx_model .graph .value_info [:]
786
800
onnx_model .graph .value_info .extend (keep )
787
801
788
- # Also clean up graph.output entries
789
- keep_outputs = [
790
- out for out in onnx_model .graph .output if out .name not in removed_output_names
791
- ]
792
- del onnx_model .graph .output [:]
793
- onnx_model .graph .output .extend (keep_outputs )
794
-
795
802
return onnx_model
0 commit comments