# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

import torch

from executorch.backends.nxp.backend.ir.converter.node_converters.shared.conv_utils import (
    group_conv_convertible_into_multiple_convolutions,
)
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter


class SplitGroupConvolution(PassBase):
    """The eIQ Neutron NPU supports only regular and depthwise convolutions. Group convolutions must be decomposed
    into multiple parallel single-group convolutions.

    Replace the nodes in the following pattern. The square brackets indicate the tensor shapes.

                                                                              │[N, Ic, ...]
                                                                          ┌───▼───┐
                                                                          │ split │
                                                                          └┬─────┬┘
                                                        ┌──────────────────┘ ... └────────────────────┐
           │[N, Ic, ...]                                │[N, Ic/G, ...]                               │[N, Ic/G, ...]
    ┌──────▼──────┐                              ┌──────▼──────┐                               ┌──────▼──────┐
    │ convolution ◄──W [Oc, Ic/G, ...] replace   │ convolution ◄──W [Oc/G, Ic/G, ...]          │ convolution ◄──W [Oc/G, Ic/G, ...]
    │   group=G   ◄──B [Oc]          ────────►   │   group=1   ◄──B [Oc/G]              ...    │   group=1   ◄──B [Oc/G]
    └──────┬──────┘                     with     └──────┬──────┘                               └──────┬──────┘
           ▼[N, Oc, ...]                                │[N, Oc/G, ...]                               │[N, Oc/G, ...]
                                                        └──────────────────┐ ... ┌────────────────────┘
                                                                          ┌▼─────▼┐
                                                                          │  cat  │
                                                                          └───┬───┘
                                                                              ▼[N, Oc, ...]
    """
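
    # Minimal usage sketch (hedged): `graph_module` stands in for however the surrounding
    # pipeline obtains a GraphModule whose static weights appear as `get_attr` nodes.
    #
    #   pass_result = SplitGroupConvolution()(graph_module)  # `PassBase.__call__` runs `call()`.
    #   if pass_result.modified:
    #       pass_result.graph_module.recompile()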

    module: GraphModule

    def _get_tensor_constant_from_node(self, node: Node | None) -> Parameter | None:
        """Get the static data from a given node. If it doesn't have any data, return `None`."""
        if node is None or node.op != "get_attr":
            return None

        target_atoms = node.target.split(".")
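        # For example, a dotted target such as "conv.weight" (hypothetical name) resolves
        # to `self.module.conv.weight` via the attribute walk below.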
        attr_itr = self.module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def _create_and_insert_get_item_node(self, input_node: Node, idx: int) -> Node:
        """Create a `GetItem` node which extracts the output of `input_node` at index `idx`.
        The `GetItem` is also added to the graph right after the `input_node`.
        """
        with self.module.graph.inserting_after(input_node):
            get_item_node = self.module.graph.create_node(
                "call_function",
                operator.getitem,
                (input_node, idx),
                {},
            )

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        get_item_node.meta["source_fn_stack"] = [
            (get_item_node.name, input_node.meta["source_fn_stack"])
        ]
        get_item_node.meta["val"] = input_node.meta["val"][idx]

        return get_item_node

    def _create_split_node(self, *split_args) -> Node:
        split_target = torch.ops.aten.split.default
        split_node = self.module.graph.call_function(split_target, split_args)

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        split_node.meta["source_fn_stack"] = [(split_node.name, torch.split)]

        # Compute the output shapes for the `split`, and assign the `val` meta.
        x_val = split_args[0].meta["val"]
        with FakeTensorMode() as mode:
            fake_input = FakeTensor.from_tensor(
                torch.empty(x_val.shape, dtype=x_val.dtype), mode
            )
            output_shapes = [t.shape for t in split_target(fake_input, *split_args[1:])]
            split_node.meta["val"] = tuple(
                FakeTensor.from_tensor(torch.empty(shape, dtype=x_val.dtype), mode)
                for shape in output_shapes
            )

        return split_node

    def _create_convolution_node(self, conv_target, args: tuple) -> Node:
        convolution_node = self.module.graph.call_function(conv_target, args)

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        convolution_node.meta["source_fn_stack"] = [
            (convolution_node.name, torch.convolution)
        ]

        # Compute the output shape for the `convolution`, and assign the `val` meta.
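        # `args[:3]` hold the input, weight, and bias. Nodes carry their shape and dtype in
        # `meta["val"]`, while plain tensors expose them directly, hence the `hasattr` checks.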
        with FakeTensorMode() as mode:
            input_shapes = [
                input_.meta["val"].shape if hasattr(input_, "meta") else input_.shape
                for input_ in args[:3]
            ]
            input_dtypes = [
                input_.meta["val"].dtype if hasattr(input_, "meta") else input_.dtype
                for input_ in args[:3]
            ]
            fake_inputs = [
                FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode)
                for shape, dtype in zip(input_shapes, input_dtypes)
            ]
            output = conv_target(*fake_inputs, *args[3:])
            convolution_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(output.shape, dtype=output.dtype), mode
            )

        return convolution_node

    def _create_concat_node(self, *cat_args) -> Node:
        cat_target = torch.ops.aten.cat.default
        concat_node = self.module.graph.call_function(cat_target, cat_args)

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        concat_node.meta["source_fn_stack"] = [(concat_node.name, torch.cat)]

        # Compute the output shape for the `concat`, and assign the `val` meta.
        with FakeTensorMode() as mode:
            fake_inputs = [
                FakeTensor.from_tensor(
                    torch.empty(
                        input_.meta["val"].shape, dtype=input_.meta["val"].dtype
                    ),
                    mode,
                )
                for input_ in cat_args[0]
            ]
            output = cat_target(fake_inputs, *cat_args[1:])
            concat_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(output.shape, dtype=output.dtype), mode
            )

        return concat_node

    def _get_topologically_last_node(self, nodes: list[Node]) -> Node:
        """Return the node from `nodes` which appears last in the graph."""
        for node in reversed(self.module.graph.nodes):
            if node in nodes:
                return node

        raise RuntimeError(f"None of the nodes `{nodes}` are in the graph.")

    def _create_parameter_node_for_data(
        self, data: torch.Tensor, name: str, insert_after_node: Node
    ) -> Node:
        """Create a parameter node in the graph, which contains the provided `data`."""
        new_name = get_new_attr_name_with_prefix(name)(self.module)

        # Create the node for the parameter.
        param = torch.nn.Parameter(data, requires_grad=False)
        _assign_attr(param, self.module, str(new_name), _AttrKind.PARAMETER)
        with self.module.graph.inserting_after(insert_after_node):
            static_parameter_node = self.module.graph.get_attr(new_name)

        with FakeTensorMode() as mode:
            static_parameter_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(data.shape, dtype=data.dtype), mode
            )

        return static_parameter_node

    def call(self, module: GraphModule):
        self.module = module

        def _is_conv(node_: Node):
            return node_.op == "call_function" and node_.target in (
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
            )

        made_changes = False

        for node in self.module.graph.nodes:
            if not _is_conv(conv_node := node):
                continue

            if len(conv_node.args) < 7:
                # The `aten.conv` can have fewer args if the trailing ones use their default
                # values, in which case `groups == 1` and there is nothing to decompose.
                continue
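            # `aten.conv1d` / `aten.conv2d` positional args, in order:
            # (input, weight, bias, stride, padding, dilation, groups).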
            x, w, b, stride, padding, dilation, groups = conv_node.args

            if not group_conv_convertible_into_multiple_convolutions(conv_node, groups):
                continue

            if len(x.meta["val"].shape) not in [3, 4]:
                # Only 1D and 2D convolutions are supported by the Neutron backend. Don't decompose anything else.
                continue

            w_data = self._get_tensor_constant_from_node(w)
            b_data = self._get_tensor_constant_from_node(b)
            if w_data is None or b_data is None:
                continue  # Only the standard case with static weights and bias is supported.

            # Create a `split` node to split the main input.
            # Split across dimension `1` (channels), into `groups` slices of `num_input_channels // groups` channels each.
            num_input_channels = x.meta["val"].shape[1]
            input_split_sizes = [num_input_channels // groups] * groups
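            # For example, with 16 input channels and `groups == 4`,
            # `input_split_sizes == [4, 4, 4, 4]`.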
            with self.module.graph.inserting_before(conv_node):
                split_node = self._create_split_node(x, input_split_sizes, 1)

            # Add `GetItem` nodes to extract the outputs of the `split_node`.
            split_getitem_nodes = [
                self._create_and_insert_get_item_node(split_node, i)
                for i in range(groups)
            ]

            # Split the weights and bias across dimension `0`, into slices of size `weight_split_size`.
            weight_split_size = w.meta["val"].shape[0] // groups
            split_weights_data = torch.split(w_data, weight_split_size, 0)
            split_bias_data = torch.split(b_data, weight_split_size, 0)
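            # For example, weights of shape [8, 4, 3, 3] with `groups == 4` are split into
            # four chunks of shape [2, 4, 3, 3], one per single-group convolution.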

            # Turn the weights and biases into parameter nodes containing the data.
            # Use a different name prefix for every parameter. `get_new_attr_name_with_prefix`
            # ensures uniqueness on its own, but relying on it alone has been observed to
            # cause failures when `groups > 5`, for reasons that are not clear.
            split_weight_nodes = [
                self._create_parameter_node_for_data(
                    weight_data, w.name + f"_{i}_", split_node
                )
                for i, weight_data in enumerate(split_weights_data)
            ]
            split_bias_nodes = [
                self._create_parameter_node_for_data(
                    bias_data, b.name + f"_{i}_", split_node
                )
                for i, bias_data in enumerate(split_bias_data)
            ]

            # Create the `conv` nodes.
            with self.module.graph.inserting_after(
                self._get_topologically_last_node(
                    split_getitem_nodes + split_weight_nodes + split_bias_nodes
                )
            ):
                split_conv_nodes = [
                    self._create_convolution_node(
                        conv_node.target,  # Use the same target as the original convolution (conv1d / conv2d).
                        (input_getitem, weight, bias, stride, padding, dilation, 1),
                    )
                    for input_getitem, weight, bias in zip(
                        split_getitem_nodes, split_weight_nodes, split_bias_nodes
                    )
                ]

            # Create the `cat` node.
            with self.module.graph.inserting_after(
                self._get_topologically_last_node(split_conv_nodes)
            ):
                concat_node = self._create_concat_node(
                    split_conv_nodes, 1
                )  # Concatenate along the channels.

            # Replace the uses of the original convolution with the `concat_node`, and remove it.
            conv_node.replace_all_uses_with(concat_node)
            self.module.graph.erase_node(conv_node)

            made_changes = True

        return PassResult(self.module, made_changes)
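

# A minimal smoke-test sketch, kept as a comment because it rests on assumptions this file
# does not guarantee: that `torch.export.export` keeps `aten.conv2d` nodes with all seven
# positional args, and that `ExportedProgram.module()` re-installs the parameters as
# `get_attr` nodes. If either assumption fails, the pass simply reports no changes.
#
#   conv = torch.nn.Conv2d(16, 8, kernel_size=3, groups=4)
#   gm = torch.export.export(conv, (torch.randn(1, 16, 32, 32),)).module()
#   result = SplitGroupConvolution().call(gm)
#   print(result.modified)  # True if the group convolution was decomposed.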