@@ -6,10 +6,15 @@
 # pyre-unsafe
 
 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import get_buffer, get_param
+from torch.export.graph_signature import InputKind
 from torch.fx import Node
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 
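Note on the folding itself: the heavy lifting is still delegated to torch.nn.utils.fusion.fuse_conv_bn_weights, imported above. For intuition, here is a minimal sketch of the standard conv/batchnorm folding math that function implements (an illustration, not the library's actual source; it assumes all arguments are tensors, whereas the real helper also accepts None for the bias and affine terms):

import torch

def fold_conv_bn(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # Per-output-channel scale: gamma / sqrt(running_var + eps)
    scale = bn_w / torch.sqrt(bn_rv + bn_eps)
    # Scale each output-channel slice of the conv weight
    fused_w = conv_w * scale.reshape(-1, *([1] * (conv_w.dim() - 1)))
    # Shift the bias so conv(x) alone matches batchnorm(conv(x))
    fused_b = (conv_b - bn_rm) * scale + bn_b
    return fused_w, fused_b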
@@ -23,7 +28,7 @@ def __init__(self, exported_program: ExportedProgram):
         self.exported_program = exported_program
         super().__init__()
 
-    def is_fuseable_conv_bn(self, node: Node):
+    def is_fuseable_conv_bn(self, node: Node) -> bool:
         """Returns True if node is a batchnorm that can be fused into
         a parent convolution."""
         if node.op != "call_function":
@@ -44,15 +49,19 @@ def is_fuseable_conv_bn(self, node: Node):
         # Since we change the output of the conv, fuse only if it has a single user.
         if len(conv.users) > 1:
             return False
-        # For similar reasons, only fuse if the conv parameters have a single user.
-        if len(conv.all_input_nodes[1].users) > 1:
-            return False
-        if len(conv.all_input_nodes) > 2 and len(conv.all_input_nodes[2].users) > 1:
-            return False
         return True
 
+    def get_bias_name(self, conv_weight_node: Node, conv_bias_node: Node) -> str:
+        if conv_bias_node:
+            return conv_bias_node.name + "_fused_bn"
+        elif "weight" in conv_weight_node.name:
+            return conv_weight_node.name.replace("weight", "bias") + "_fused_bn"
+        else:
+            return conv_weight_node.name + "_bias_fused_bn"
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # noqa: C901
         modified = False
+        constant_placeholders_to_delete = set()
         for node in graph_module.graph.nodes:
             if not self.is_fuseable_conv_bn(node):
                 continue
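The three branches of the new get_bias_name helper cover: a conv that already has a bias, a bias-less conv whose weight placeholder follows the usual "*_weight" naming, and a fallback for any other weight name. Illustrative outcomes (the placeholder names here are hypothetical):

# conv bias node named "conv1_bias"           -> "conv1_bias_fused_bn"
# no bias, weight node named "conv1_weight"   -> "conv1_bias_fused_bn"
# no bias, weight node named "conv1_w"        -> "conv1_w_bias_fused_bn"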
@@ -64,68 +73,93 @@ def get_param_or_none(arg) -> torch.nn.Parameter | None:
                 )
 
             # Get weight, bias, mean, var and epsilon from the batchnorm
-            bn = node
-            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = bn.args[0:5]
-            bn_weight = get_param_or_none(bn_weight_node)
-            bn_bias = get_param_or_none(bn_bias_node)
-
-            running_mean = get_buffer(self.exported_program, bn_mean_node)
-            running_var = get_buffer(self.exported_program, bn_var_node)
-            if running_mean is None or running_var is None:
+            bn_node = node
+            conv, bn_weight_node, bn_bias_node, bn_mean_node, bn_var_node = (
+                bn_node.args[0:5]
+            )
+            bn_weight_tensor = get_param_or_none(bn_weight_node)
+            bn_bias_tensor = get_param_or_none(bn_bias_node)
+            bn_mean_tensor = get_buffer(self.exported_program, bn_mean_node)
+            bn_var_tensor = get_buffer(self.exported_program, bn_var_node)
+            if bn_mean_tensor is None or bn_var_tensor is None:
                 raise ValueError(
                     "Parameters running_mean and running_var of batchnorm can't be None."
                 )
-            epsilon = bn.args[-1]
+            epsilon = bn_node.args[-1]
 
             # Get weight and bias from conv
             conv_weight_node, conv_bias_node = conv.args[1:3]
-            conv_weight = get_param(self.exported_program, conv_weight_node)
-            conv_bias = get_param_or_none(conv_bias_node)
-            if conv_weight is None:
+            conv_weight_tensor = get_param(self.exported_program, conv_weight_node)
+            conv_bias_tensor = get_param_or_none(conv_bias_node)
+            if conv_weight_tensor is None:
                 raise ValueError("Parameter weight of convolution can't be None.")
 
             # Compute conv parameters folded with batchnorm
             fused_conv_weight, fused_conv_bias = fuse_conv_bn_weights(
-                conv_weight,
-                conv_bias,
-                running_mean,
-                running_var,
+                conv_weight_tensor,
+                conv_bias_tensor,
+                bn_mean_tensor,
+                bn_var_tensor,
                 epsilon,
-                bn_weight,
-                bn_bias,
+                bn_weight_tensor,
+                bn_bias_tensor,
             )
 
-            # Set the conv parameters to fused value
-            def try_set_param(
-                param_node: Node | None, param_value: torch.nn.Parameter
-            ) -> bool:
-                """set_param but check if param_node is None first. Return True if param was set successfully, otherwise False."""
-                if param_node is not None:
-                    param_name = (
-                        self.exported_program.graph_signature.inputs_to_parameters[
-                            param_node.name
-                        ]
+            # Create placeholders for the fused weight and bias, and rewire the conv args
+            with graph_module.graph.inserting_before(conv_weight_node):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=conv_weight_node.name + "_fused_bn",
+                    data=fused_conv_weight,
+                )
+
+                if fused_conv_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=self.get_bias_name(conv_weight_node, conv_bias_node),
+                        data=fused_conv_bias,
                     )
-                    self.exported_program.state_dict[param_name] = param_value
-                    return True
-                return False
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )
 
-            try_set_param(conv_weight_node, fused_conv_weight)
-            if not try_set_param(conv_bias_node, fused_conv_bias) and try_set_param(
-                bn_bias_node, fused_conv_bias
-            ):
-                # pyre-ignore[60]
-                # Conv didn't have bias but batchnorm did, steal bias from batchnorm.
-                conv_args = (*conv.args[0:2], bn_bias_node, *conv.args[3:])
-                conv.args = conv_args
-
-            # Erasing nodes is handled by dead-code elimination.
-            for user in bn.users:
+            # Erasing the batchnorm nodes is handled by dead-code elimination.
+            # After that, we can remove their now-unused constant placeholder inputs.
+            for user in bn_node.users:
                 user.replace_all_uses_with(conv)
+
+            constant_placeholders_to_delete.update(
+                [
+                    bn_weight_node,
+                    bn_bias_node,
+                    bn_mean_node,
+                    bn_var_node,
+                    conv_weight_node,
+                    conv_bias_node,
+                ]
+            )
             modified = True
 
         if modified:
             graph_module.graph.eliminate_dead_code()
+            for constant_placeholder in constant_placeholders_to_delete:
+                if (constant_placeholder is not None) and (
+                    len(constant_placeholder.users) == 0
+                ):
+                    delete_constant_placeholder(
+                        self.exported_program, constant_placeholder
+                    )
+
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
+
         return PassResult(graph_module=graph_module, modified=modified)
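Taken together, the pass now materializes fresh PARAMETER placeholders for the fused values instead of writing into the old state_dict entries, and then garbage-collects any constant placeholders that dead-code elimination left without users. A hedged end-to-end sketch of running the pass (the enclosing class name, FuseBatchnormPass here, is not visible in this diff and is assumed):

import torch
from executorch.exir import to_edge

class ConvBN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))

model = ConvBN().eval()
example = (torch.randn(1, 3, 16, 16),)
edge = to_edge(torch.export.export(model, example))
exported = edge.exported_program()

# ExportPass instances are callable and return a PassResult.
result = FuseBatchnormPass(exported)(exported.graph_module)
assert result.modified
# The batchnorm node should be gone from the fused graph.
assert not any(
    "batch_norm" in str(node.target)
    for node in result.graph_module.graph.nodes
)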