# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator

import torch
from executorch.backends.xnnpack._passes.xnnpack_pass import XNNPACKPass
from executorch.backends.xnnpack.utils.utils import (
    check_or_raise,
    get_param_tensor,
    is_param_node,
)
from executorch.exir.backend.utils import WhyNoPartition
from executorch.exir.dialects._ops import ops as exir_ops
from torch.fx.passes.infra.pass_base import PassResult

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class DecomposeBatchNorm(XNNPACKPass):
    """
    Decompose batch norm operators into an equivalent 1x1 depthwise convolution.
    """
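    # Structural sketch of the rewrite (shapes only): an eval-mode BatchNorm2d over an
    # [N, C, H, W] input becomes aten.convolution with a [C, 1, 1, 1] weight, a [C]
    # bias, and groups == C, i.e. a per-channel scale and shift expressed as a 1x1
    # depthwise convolution. NC and NCL inputs are handled the same way with a 1d
    # conv, with NC inputs unsqueezed to NCL around the conv.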

    BATCH_NORM_OPS = {
        exir_ops.edge.aten.native_batch_norm.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
    }

    @staticmethod
    def can_decompose_batch_norm(
        node: torch.fx.Node,
        exported_program: torch.export.ExportedProgram,
        why: WhyNoPartition | None = None,
    ) -> bool:
        """
        Determine whether the given batch norm node can be decomposed by this pass.
        """

        if (
            node.op != "call_function"
            or node.target not in DecomposeBatchNorm.BATCH_NORM_OPS
        ):
            return False

        input_meta = node.args[0].meta["val"]

        # Since we're converting to conv and XNNPACK doesn't support conv3d, we can't
        # handle BatchNorm3d. Validate the input dimension. We'll take NC, NCL, or NCHW.
        if input_meta.dim() not in (2, 3, 4):
            if why:
                why(
                    node,
                    f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.",
                )
            return False

        # The batch norm op returns a tuple whose first element is the output; the
        # remaining elements are saved statistics that this pass does not use.
        # All users must be getitem nodes that fetch the output (index 0).
        # The partitioner should enforce this, but we'll check it here too.
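        # For example, the expected pattern in the graph looks like:
        #   bn = native_batch_norm(x, gamma, beta, running_mean, running_var, ...)
        #   out = operator.getitem(bn, 0)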
        for user in node.users:
            if user.target != operator.getitem or user.args[1] != 0:
                if why:
                    why(node, "Batch norm users must only access the output tensor.")
                return False

        # Channel dimension and non-input args must be statically known.
        if not isinstance(input_meta.shape[1], int):
            if why:
                why(
                    node,
                    f"Channel dimension must be statically known, but was {input_meta.shape[1]}.",
                )
            return False

        if not is_param_node(exported_program, node.args[1]) or not is_param_node(
            exported_program, node.args[2]
        ):
            if why:
                why(node, "Batch norm affine weight and bias must be static.")
            return False

        if not is_param_node(exported_program, node.args[3]) or not is_param_node(
            exported_program, node.args[4]
        ):
            if why:
                why(node, "Batch norm running mean and variance must be static.")
            return False

        if isinstance(node.args[-1], torch.fx.Node):
            if why:
                why(node, "Batch norm epsilon must be static.")
            return False

        return True

    @staticmethod
    def compute_w_and_b(
        eps: float,
        running_mean: torch.Tensor,  # [C]
        running_var: torch.Tensor,  # [C]
        gamma: torch.Tensor,  # [C], learned weight
        beta: torch.Tensor,  # [C], learned bias
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute equivalent per-channel weight and bias to match the batch norm
        computation with frozen values.
        """

        # See https://docs.pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
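        # With frozen statistics, each channel is normalized as
        #   y = (x - running_mean) / sqrt(running_var + eps) * gamma + beta
        # which rearranges to y = x * weight + bias with
        #   weight = gamma / sqrt(running_var + eps)
        #   bias   = beta - running_mean * weight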
        denom = torch.sqrt(running_var + eps)
        weight = gamma / denom
        bias = beta - running_mean * weight

        return weight, bias

    @staticmethod
    def replace_bn_node_with_conv(
        bn_node: torch.fx.Node,
        graph_module: torch.fx.GraphModule,
        exported_program: torch.export.ExportedProgram,
    ) -> torch.fx.Node:
        """
        Replace a BatchNorm with NCL or NCHW input with an equivalent depthwise
        convolution.
        """

        # Compute the equivalent per-channel weights and biases.
        # Note that the batch norm node args are
        # (input, gamma, beta, running_mean, running_var, [training], momentum, eps).
        # The training arg is not present in the _no_training variant.
        weight, bias = DecomposeBatchNorm.compute_w_and_b(
            eps=bn_node.args[-1],
            running_mean=get_param_tensor(exported_program, bn_node.args[3]),
            running_var=get_param_tensor(exported_program, bn_node.args[4]),
            gamma=get_param_tensor(exported_program, bn_node.args[1]),
            beta=get_param_tensor(exported_program, bn_node.args[2]),
        )

        with graph_module.graph.inserting_after(bn_node):
            # Conv weights have shape [out_c, in_c/g, spatial...].
            # For depthwise conv, groups == in_c, so in_c/g == 1, and the kernel is
            # 1x1 (or just 1, for 1d).
            #
            # BatchNorm weights have shape [in_c], so we just need to unsqueeze
            # [in_c] to [in_c, 1, 1] (1d) or [in_c, 1, 1, 1] (2d).
            input_meta = bn_node.args[0].meta["val"]
            channel_count = input_meta.shape[1]
            spatial_dims = max(input_meta.dim() - 2, 1)  # Min of 1 since 1d can be NC or NCL.
            new_weight_shape = [weight.shape[0], 1] + [1] * spatial_dims
            weight = weight.reshape(new_weight_shape)

            # Insert the new weight and bias in the graph.
            # TODO?

            conv_node = graph_module.graph.call_function(
                exir_ops.edge.aten.convolution.default,
                args=(
                    bn_node.args[0],  # Input
                    weight,  # Weight
                    bias,  # Bias
                    [1] * spatial_dims,  # Stride
                    [0] * spatial_dims,  # Padding
                    [1] * spatial_dims,  # Dilation
                    False,  # Transposed
                    [0] * spatial_dims,  # Output padding
                    channel_count,  # Groups (depthwise, so groups == in_channels)
                ),
            )

            # Find the getitem user nodes and replace them with the conv node.
            # The decomposition checks above enforce that the node is only used by
            # getitem(..., 0).
            users = list(bn_node.users)
            for user in users:
                user.replace_all_uses_with(conv_node)
                graph_module.graph.erase_node(user)

        graph_module.graph.erase_node(bn_node)
        return conv_node

    def decompose_node(
        self, node: torch.fx.Node, graph_module: torch.fx.GraphModule
    ) -> None:
        input_meta = node.args[0].meta["val"]

        # These should be checked by the partitioner and the caller, so we should
        # never fail these checks.
        check_or_raise(
            node.op == "call_function"
            and node.target in DecomposeBatchNorm.BATCH_NORM_OPS,
            f"Invalid batch norm operator {node.target}.",
        )

        check_or_raise(
            input_meta.dim() in (2, 3, 4),
            f"Unsupported input rank {input_meta.dim()} for XNN batch norm operator.",
        )

        channel_count = input_meta.shape[1]
        check_or_raise(
            isinstance(channel_count, int),
            f"Channel dimension must be statically known, but was {channel_count}.",
        )

        # Create the convolution node.
        conv_node = self.replace_bn_node_with_conv(
            node, graph_module, self.exported_program
        )

        # BatchNorm1d can be NC or NCL. Conv1d requires the L dim, so unsqueeze NC -> NCL.
        if input_meta.dim() == 2:
            with graph_module.graph.inserting_before(conv_node):
                # Insert unsqueeze node before the conv.
                unsqueeze_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.unsqueeze_copy.default,
                    args=(conv_node.args[0], 2),
                )
                conv_node.args = (unsqueeze_node, *conv_node.args[1:])

            with graph_module.graph.inserting_after(conv_node):
                # Insert squeeze node after the conv. Skip the squeeze node itself when
                # rerouting users so that it keeps the conv node as its input.
                squeeze_node = graph_module.graph.call_function(
                    exir_ops.edge.aten.squeeze_copy.dim,
                    args=(conv_node, 2),
                )
                conv_node.replace_all_uses_with(
                    squeeze_node, delete_user_cb=lambda user: user is not squeeze_node
                )

    # override
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Decompose every supported batch norm node in the graph.
        for node in graph_module.graph.nodes:
            if node.op == "call_function" and node.target in self.BATCH_NORM_OPS:
                if self.can_decompose_batch_norm(node, self.exported_program):
                    self.decompose_node(node, graph_module)

        graph_module.recompile()

        # Propagate metadata and retrace the module.
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
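
if __name__ == "__main__":
    # Minimal numeric sanity check (a sketch, not part of the pass itself): with frozen
    # running stats, applying the folded weight/bias as a 1x1 depthwise conv should
    # reproduce the eval-mode batch norm output. BatchNorm2d and the channel count of 8
    # are arbitrary illustration choices.
    with torch.no_grad():
        bn = torch.nn.BatchNorm2d(8).eval()
        bn.running_mean.uniform_(-1.0, 1.0)
        bn.running_var.uniform_(0.5, 2.0)
        x = torch.randn(2, 8, 4, 4)

        w, b = DecomposeBatchNorm.compute_w_and_b(
            eps=bn.eps,
            running_mean=bn.running_mean,
            running_var=bn.running_var,
            gamma=bn.weight,
            beta=bn.bias,
        )
        y_conv = torch.nn.functional.conv2d(x, w.reshape(8, 1, 1, 1), b, groups=8)
        assert torch.allclose(bn(x), y_conv, atol=1e-6)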