@@ -7,13 +7,22 @@
 import operator
 
 import torch
+from executorch.backends.transforms.utils import (
+    create_constant_placeholder,
+    delete_constant_placeholder,
+)
 
 from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
 
-from executorch.backends.xnnpack.utils.utils import get_param_tensor, is_param_node
+from executorch.backends.xnnpack.utils.utils import (
+    get_param_tensor,
+    get_tensor_name,
+    is_param_node,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import PassResult
+from torch.export.graph_signature import InputKind
 
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 
@@ -28,7 +37,7 @@ class FuseBatchNormWithConvPass(XNNPACKPass):
 
     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
-        counter = 0
+        constant_placeholders_to_delete = set()
         for conv in graph.nodes:
             # We want to discover a chain of conv -> batch_norm.
             # Only proceed if the current node is a conv node, and has a single
@@ -55,9 +64,11 @@ def call(self, graph_module: torch.fx.GraphModule):
             assert len(conv.args) == 9
 
             conv_weight = get_param_tensor(self.exported_program, conv.args[1])
+            conv_weight_name = get_tensor_name(self.exported_program, conv.args[1])
             assert conv_weight is not None
 
             conv_bias = get_param_tensor(self.exported_program, conv.args[2])
+            conv_bias_name = get_tensor_name(self.exported_program, conv.args[2])
 
             # Get the parameters from the batchnorm op
             assert (
@@ -95,32 +106,57 @@ def call(self, graph_module: torch.fx.GraphModule):
                 bn_bias,
                 is_transpose,
             )
+            fused_weight_name = (conv_weight_name + "_fused_bn").replace(".", "_")
+            if conv_bias_name == "":
+                fused_bias_name = (conv_weight_name + "_bias_fused_bn").replace(
+                    ".", "_"
+                )
+            else:
+                fused_bias_name = (conv_bias_name + "_fused_bn").replace(".", "_")
 
             # Modify the graph by updating the weight and bias of conv op
             # with the fused weight and bias params, and replacing all the users
             # of getitem(batchnorm) with the conv op.
-            with graph.inserting_before(conv):
-                fused_weight_name = f"_fused_with_bn_weight_{counter}"
-                graph_module.register_parameter(fused_weight_name, fused_weight)
-                fused_weight_node = graph.get_attr(fused_weight_name)
-                fused_bias_name = f"_fused_with_bn_bias_{counter}"
-                graph_module.register_parameter(fused_bias_name, fused_bias)
-                fused_bias_node = graph.get_attr(fused_bias_name)
-
-                # Update the weight and bias of conv op
-                conv_args = list(conv.args) + ([None] if len(conv.args) == 2 else [])
-                conv_args[1] = fused_weight_node
-                conv_args[2] = fused_bias_node
-                conv.args = tuple(conv_args)
+            with graph.inserting_before(conv.args[1]):
+                fused_conv_weight_node = create_constant_placeholder(
+                    exp_program=self.exported_program,
+                    graph=graph_module.graph,
+                    kind=InputKind.PARAMETER,
+                    name=fused_weight_name,
+                    data=fused_weight,
+                )
+                if fused_bias is not None:
+                    fused_conv_bias_node = create_constant_placeholder(
+                        exp_program=self.exported_program,
+                        graph=graph_module.graph,
+                        kind=InputKind.PARAMETER,
+                        name=fused_bias_name,
+                        data=fused_bias,
+                    )
+                else:
+                    fused_conv_bias_node = None
+
+                conv.args = (
+                    conv.args[0],
+                    fused_conv_weight_node,
+                    fused_conv_bias_node,
+                    *conv.args[3:],
+                )
+
             # Remove any use of batchnorm from the graph
             for user in bn.users.copy():
                 assert user.target == operator.getitem
                 user.replace_all_uses_with(conv)
                 graph.erase_node(user)
 
             graph.erase_node(bn)
+            constant_placeholders_to_delete.update(conv.args[1:3] + bn.args[1:5])
 
-            counter += 1
+        if len(constant_placeholders_to_delete) > 0:
+            graph_module.graph.eliminate_dead_code()
+            for node in constant_placeholders_to_delete:
+                if (node is not None) and (len(node.users) == 0):
+                    delete_constant_placeholder(self.exported_program, node)
 
         graph_module.recompile()
         # To regenerate metadata and shape information, retrace the module
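
For reference, the rewrite above leaves all of the arithmetic to `torch.nn.utils.fusion.fuse_conv_bn_weights`: given BatchNorm running statistics mu, var and affine parameters gamma, beta, it folds the normalization into the convolution as W' = W * gamma / sqrt(var + eps), scaled per output channel, and b' = (b - mu) * gamma / sqrt(var + eps) + beta. Below is a minimal standalone sketch (not part of this diff) that checks the fold numerically; the module sizes, input shape, and tolerances are illustrative assumptions.

import torch
from torch.nn.utils.fusion import fuse_conv_bn_weights

conv = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
bn = torch.nn.BatchNorm2d(8).eval()
# Give BN non-trivial running statistics so the check is meaningful.
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)

with torch.no_grad():
    fused_w, fused_b = fuse_conv_bn_weights(
        conv.weight,
        conv.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
        transpose=False,  # the pass forwards is_transpose here for ConvTranspose
    )

x = torch.randn(1, 3, 16, 16)
expected = bn(conv(x))  # original conv -> batch_norm chain
actual = torch.nn.functional.conv2d(x, fused_w, fused_b)
torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)

The pass performs exactly this fold at export time, then swaps the conv node's weight and bias inputs for new constant placeholders holding the fused tensors and, after dead-code elimination, deletes the now-unused BatchNorm constants via delete_constant_placeholder.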