|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import operator |
| 8 | +from typing import Optional |
| 9 | + |
| 10 | +import torch |
| 11 | +from executorch.backends.transforms.utils import ( |
| 12 | + create_constant_placeholder, |
| 13 | + delete_constant_placeholder, |
| 14 | +) |
| 15 | +from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass |
| 16 | +from executorch.backends.xnnpack.utils.utils import ( |
| 17 | + get_param_tensor, |
| 18 | + get_tensor_name, |
| 19 | + is_param_node, |
| 20 | +) |
| 21 | +from executorch.exir import ExportedProgram |
| 22 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 23 | +from executorch.exir.pass_base import PassResult |
| 24 | +from torch.export.graph_signature import InputKind |
| 25 | + |
| 26 | + |
| 27 | +class ConvertBatchNormToDepthwiseConvPass(XNNPACKPass): |
| 28 | + """ |
| 29 | + Converts standalone batch norm operations to depthwise convolutions. |
| 30 | + This allows XNNPACK to handle batch norm operations that cannot be fused |
| 31 | + with preceding convolutions. |
| 32 | + |
| 33 | + BatchNorm formula: y = (x - mean) / sqrt(var + eps) * weight + bias |
| 34 | + This can be represented as a 1x1 depthwise convolution with: |
| 35 | + - conv_weight = weight / sqrt(var + eps) |
| 36 | + - conv_bias = bias - mean * weight / sqrt(var + eps) |
| 37 | + """ |
| 38 | + |
| 39 | + def call(self, graph_module: torch.fx.GraphModule): |
| 40 | + graph = graph_module.graph |
| 41 | + constant_placeholders_to_delete = set() |
| 42 | + nodes_to_convert = [] |
| 43 | + |
| 44 | + # First pass: identify standalone batch norm nodes |
| 45 | + for node in graph.nodes: |
| 46 | + if ( |
| 47 | + node.target != exir_ops.edge.aten._native_batch_norm_legit_no_training.default |
| 48 | + and node.target != exir_ops.edge.aten.native_batch_norm.default |
| 49 | + ): |
| 50 | + continue |
| 51 | + |
| 52 | + # Check if this batch norm can be fused with a preceding conv |
| 53 | + # If so, skip it - the fusion pass will handle it |
| 54 | + if self._can_be_fused_with_conv(node): |
| 55 | + continue |
| 56 | + |
| 57 | + # Check if this is a valid standalone batch norm to convert |
| 58 | + if self._can_convert_to_depthwise_conv(node): |
| 59 | + nodes_to_convert.append(node) |
| 60 | + |
| 61 | + # Second pass: convert the identified nodes |
| 62 | + for bn_node in nodes_to_convert: |
| 63 | + conv_node = self._convert_batch_norm_to_depthwise_conv( |
| 64 | + graph_module, bn_node, constant_placeholders_to_delete |
| 65 | + ) |
| 66 | + if conv_node is not None: |
| 67 | + # Replace all uses of batch norm getitem(0) with the conv node |
| 68 | + for user in list(bn_node.users): |
| 69 | + if user.target == operator.getitem and user.args[1] == 0: |
| 70 | + user.replace_all_uses_with(conv_node) |
| 71 | + graph.erase_node(user) |
| 72 | + |
| 73 | + # Remove the batch norm node |
| 74 | + graph.erase_node(bn_node) |
| 75 | + |
| 76 | + # Clean up unused constant placeholders |
| 77 | + if constant_placeholders_to_delete: |
| 78 | + graph_module.graph.eliminate_dead_code() |
| 79 | + for node in constant_placeholders_to_delete: |
| 80 | + if node is not None and len(node.users) == 0: |
| 81 | + delete_constant_placeholder(self.exported_program, node) |
| 82 | + |
| 83 | + graph_module.recompile() |
| 84 | + # Regenerate metadata and shape information |
| 85 | + graph_module = super().call(graph_module).graph_module |
| 86 | + |
| 87 | + return PassResult(graph_module, True) |
| 88 | + |
| 89 | + def _can_be_fused_with_conv(self, bn_node: torch.fx.Node) -> bool: |
| 90 | + """Check if this batch norm can be fused with a preceding convolution.""" |
| 91 | + # Import here to avoid circular dependency |
| 92 | + from executorch.backends.xnnpack._passes.fuse_batch_norm_with_conv import ( |
| 93 | + FuseBatchNormWithConvPass, |
| 94 | + ) |
| 95 | + |
| 96 | + input_node = bn_node.all_input_nodes[0] |
| 97 | + |
| 98 | + # Check if input is a conv with single user (this batch norm) |
| 99 | + if ( |
| 100 | + input_node.target == exir_ops.edge.aten.convolution.default |
| 101 | + and len(input_node.users) == 1 |
| 102 | + ): |
| 103 | + return FuseBatchNormWithConvPass.can_fuse( |
| 104 | + input_node, bn_node, self.exported_program |
| 105 | + ) |
| 106 | + |
| 107 | + return False |
| 108 | + |
| 109 | + def _can_convert_to_depthwise_conv(self, bn_node: torch.fx.Node) -> bool: |
| 110 | + """Check if this batch norm can be converted to depthwise conv.""" |
| 111 | + |
| 112 | + # All users must be getitem ops accessing the first element (output tensor) |
| 113 | + for user in bn_node.users: |
| 114 | + if user.target != operator.getitem or user.args[1] != 0: |
| 115 | + return False |
| 116 | + |
| 117 | + # Check that we have the required parameters |
| 118 | + if len(bn_node.args) < 5: |
| 119 | + return False |
| 120 | + |
| 121 | + # Weight, bias, running_mean, running_var must be parameters |
| 122 | + param_nodes = bn_node.args[1:5] # weight, bias, running_mean, running_var |
| 123 | + |
| 124 | + for param_node in param_nodes: |
| 125 | + if not isinstance(param_node, torch.fx.Node): |
| 126 | + return False |
| 127 | + if not is_param_node(self.exported_program, param_node): |
| 128 | + return False |
| 129 | + |
| 130 | + return True |
| 131 | + |
| 132 | + def _convert_batch_norm_to_depthwise_conv( |
| 133 | + self, |
| 134 | + graph_module: torch.fx.GraphModule, |
| 135 | + bn_node: torch.fx.Node, |
| 136 | + constant_placeholders_to_delete: set, |
| 137 | + ) -> Optional[torch.fx.Node]: |
| 138 | + """Convert a batch norm node to a depthwise convolution.""" |
| 139 | + |
| 140 | + # Extract batch norm parameters |
| 141 | + input_tensor = bn_node.args[0] |
| 142 | + |
| 143 | + # Cast args to Node types for parameter access |
| 144 | + bn_weight_node = bn_node.args[1] if isinstance(bn_node.args[1], torch.fx.Node) else None |
| 145 | + bn_bias_node = bn_node.args[2] if isinstance(bn_node.args[2], torch.fx.Node) else None |
| 146 | + running_mean_node = bn_node.args[3] if isinstance(bn_node.args[3], torch.fx.Node) else None |
| 147 | + running_var_node = bn_node.args[4] if isinstance(bn_node.args[4], torch.fx.Node) else None |
| 148 | + |
| 149 | + if any(node is None for node in [bn_weight_node, bn_bias_node, running_mean_node, running_var_node]): |
| 150 | + return None |
| 151 | + |
| 152 | + # These are guaranteed to be non-None now |
| 153 | + assert bn_weight_node is not None |
| 154 | + assert bn_bias_node is not None |
| 155 | + assert running_mean_node is not None |
| 156 | + assert running_var_node is not None |
| 157 | + |
| 158 | + bn_weight = get_param_tensor(self.exported_program, bn_weight_node) |
| 159 | + bn_bias = get_param_tensor(self.exported_program, bn_bias_node) |
| 160 | + running_mean = get_param_tensor(self.exported_program, running_mean_node) |
| 161 | + running_var = get_param_tensor(self.exported_program, running_var_node) |
| 162 | + |
| 163 | + # Get epsilon value |
| 164 | + if str(bn_node.target).endswith("native_batch_norm.default"): |
| 165 | + eps = bn_node.args[7] if len(bn_node.args) > 7 else 1e-5 |
| 166 | + else: # _native_batch_norm_legit_no_training |
| 167 | + eps = bn_node.args[6] if len(bn_node.args) > 6 else 1e-5 |
| 168 | + |
| 169 | + # Ensure eps is a float |
| 170 | + if not isinstance(eps, (int, float)): |
| 171 | + eps = 1e-5 |
| 172 | + |
| 173 | + if any(param is None for param in [bn_weight, bn_bias, running_mean, running_var]): |
| 174 | + return None |
| 175 | + |
| 176 | + # Ensure all parameters are tensors |
| 177 | + assert isinstance(bn_weight, torch.Tensor) |
| 178 | + assert isinstance(bn_bias, torch.Tensor) |
| 179 | + assert isinstance(running_mean, torch.Tensor) |
| 180 | + assert isinstance(running_var, torch.Tensor) |
| 181 | + |
| 182 | + # Calculate depthwise conv parameters |
| 183 | + # BatchNorm: y = (x - mean) / sqrt(var + eps) * weight + bias |
| 184 | + # Depthwise Conv: y = x * conv_weight + conv_bias |
| 185 | + # Therefore: conv_weight = weight / sqrt(var + eps) |
| 186 | + # conv_bias = bias - mean * weight / sqrt(var + eps) |
| 187 | + |
| 188 | + inv_std = torch.rsqrt(running_var + eps) |
| 189 | + conv_weight_1d = bn_weight * inv_std |
| 190 | + conv_bias_1d = bn_bias - running_mean * conv_weight_1d |
| 191 | + |
| 192 | + # Reshape for depthwise conv: [C] -> [C, 1, 1, 1] for 2D conv |
| 193 | + # Assuming 4D input tensor [N, C, H, W] |
| 194 | + num_channels = conv_weight_1d.shape[0] |
| 195 | + conv_weight = conv_weight_1d.view(num_channels, 1, 1, 1) |
| 196 | + conv_bias = conv_bias_1d |
| 197 | + |
| 198 | + # Create parameter names |
| 199 | + bn_weight_name = get_tensor_name(self.exported_program, bn_weight_node) |
| 200 | + conv_weight_name = (bn_weight_name + "_as_depthwise_conv_weight").replace(".", "_") |
| 201 | + conv_bias_name = (bn_weight_name + "_as_depthwise_conv_bias").replace(".", "_") |
| 202 | + |
| 203 | + # Create new parameter nodes |
| 204 | + graph = graph_module.graph |
| 205 | + with graph.inserting_before(bn_node): |
| 206 | + conv_weight_node = create_constant_placeholder( |
| 207 | + exp_program=self.exported_program, |
| 208 | + graph=graph, |
| 209 | + kind=InputKind.PARAMETER, |
| 210 | + name=conv_weight_name, |
| 211 | + data=conv_weight, |
| 212 | + ) |
| 213 | + |
| 214 | + conv_bias_node = create_constant_placeholder( |
| 215 | + exp_program=self.exported_program, |
| 216 | + graph=graph, |
| 217 | + kind=InputKind.PARAMETER, |
| 218 | + name=conv_bias_name, |
| 219 | + data=conv_bias, |
| 220 | + ) |
| 221 | + |
| 222 | + # Create depthwise convolution node |
| 223 | + # Args: input, weight, bias, stride, padding, dilation, transposed, output_padding, groups |
| 224 | + conv_args = ( |
| 225 | + input_tensor, # input |
| 226 | + conv_weight_node, # weight |
| 227 | + conv_bias_node, # bias |
| 228 | + [1, 1], # stride |
| 229 | + [0, 0], # padding |
| 230 | + [1, 1], # dilation |
| 231 | + False, # transposed |
| 232 | + [0, 0], # output_padding |
| 233 | + num_channels, # groups (depthwise = groups = in_channels) |
| 234 | + ) |
| 235 | + |
| 236 | + conv_node = graph.create_node( |
| 237 | + "call_function", |
| 238 | + exir_ops.edge.aten.convolution.default, |
| 239 | + args=conv_args, |
| 240 | + ) |
| 241 | + |
| 242 | + # Mark old parameters for deletion |
| 243 | + constant_placeholders_to_delete.update(bn_node.args[1:5]) |
| 244 | + |
| 245 | + return conv_node |
| 246 | + |
| 247 | + @staticmethod |
| 248 | + def can_convert_standalone_batch_norm( |
| 249 | + bn_node: torch.fx.Node, program: ExportedProgram |
| 250 | + ) -> bool: |
| 251 | + """ |
| 252 | + Static method to check if a standalone batch norm can be converted. |
| 253 | + Used by the partitioner configuration. |
| 254 | + """ |
| 255 | + # All users must be getitem ops accessing the first element |
| 256 | + for user in bn_node.users: |
| 257 | + if user.target != operator.getitem or user.args[1] != 0: |
| 258 | + return False |
| 259 | + |
| 260 | + # Check that we have required parameters |
| 261 | + if len(bn_node.args) < 5: |
| 262 | + return False |
| 263 | + |
| 264 | + # Weight, bias, running_mean, running_var must be parameters |
| 265 | + param_nodes = bn_node.args[1:5] |
| 266 | + |
| 267 | + for param_node in param_nodes: |
| 268 | + if not isinstance(param_node, torch.fx.Node): |
| 269 | + return False |
| 270 | + if not is_param_node(program, param_node): |
| 271 | + return False |
| 272 | + |
| 273 | + return True |
0 commit comments