Commit e254f0e
NXP backend: Separable convolution decomposition in executorch (#13758)
### Summary

Introduces an aten-dialect pre-processing pass which splits `conv` nodes with `group > 1` into multiple parallel `conv` nodes with `group=1`. This replaces the original implementation in Neutron IR.

### Test plan

Unit tests provided in `backends/nxp/tests/test_split_group_convolution.py`.
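The pass relies on the standard identity between a grouped convolution and `G` parallel single-group convolutions over channel slices. A minimal sketch of that equivalence in plain PyTorch (the model, shapes, and names are illustrative, not part of this commit):

```python
import torch

# A group conv with groups=G is equivalent to G single-group convs applied to
# channel slices of the input, with the weight/bias sliced along dim 0 and the
# partial outputs concatenated along the channel dimension.
N, Ic, Oc, G = 2, 8, 12, 4
x = torch.randn(N, Ic, 16, 16)
conv = torch.nn.Conv2d(Ic, Oc, kernel_size=3, groups=G)

xs = torch.split(x, Ic // G, dim=1)            # G slices of [N, Ic/G, H, W]
ws = torch.split(conv.weight, Oc // G, dim=0)  # G slices of [Oc/G, Ic/G, k, k]
bs = torch.split(conv.bias, Oc // G, dim=0)    # G slices of [Oc/G]
y = torch.cat(
    [torch.nn.functional.conv2d(x_i, w_i, b_i) for x_i, w_i, b_i in zip(xs, ws, bs)],
    dim=1,
)

assert torch.allclose(conv(x), y, atol=1e-5)
```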
1 parent 27a08fe commit e254f0e

7 files changed: +563 -671 lines changed

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -13,11 +13,14 @@
 from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_linear_pass import (
     FuseBatchNormWithLinearPass,
 )
+from executorch.backends.nxp.aten_passes.split_group_convolution import (
+    SplitGroupConvolution,
+)
 from executorch.exir.pass_manager import PassManager
 from torch import nn
 from torch.fx.passes.infra.pass_base import PassResult
 
-PassType = list[type[Callable[[torch.fx.GraphModule], PassResult]]]
+PassType = type[Callable[[torch.fx.GraphModule], PassResult]]
 
 
 class NeutronAtenPassManager(PassManager):
@@ -26,6 +29,7 @@ def __init__(self, passes: list[PassType] = None):
         passes: list[PassType] = passes or [
             FuseBatchNormWithConvPass(),
             FuseBatchNormWithLinearPass(),
+            SplitGroupConvolution(),
         ]
 
         super().__init__(passes)
```
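With the pass registered here, it runs automatically whenever the pass manager is applied. A rough usage sketch, assuming the usual `torch.export.export_for_training(...).module()` capture and a pass manager that is callable on a `GraphModule` (the model and shapes are made up for illustration):

```python
import torch

from executorch.backends.nxp.aten_passes.neutron_aten_pass_manager import (
    NeutronAtenPassManager,
)

# Hypothetical model containing a group convolution (groups=2); not from this diff.
model = torch.nn.Conv2d(4, 8, kernel_size=3, groups=2)
example_inputs = (torch.randn(1, 4, 16, 16),)

# Capture an aten-dialect GraphModule, then run the pre-processing passes on it.
graph_module = torch.export.export_for_training(model, example_inputs).module()
graph_module = NeutronAtenPassManager()(graph_module).graph_module
```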
backends/nxp/aten_passes/split_group_convolution.py

Lines changed: 281 additions & 0 deletions (new file)
```python
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

import torch

from executorch.backends.nxp.backend.ir.converter.node_converters.shared.conv_utils import (
    group_conv_convertible_into_multiple_convolutions,
)
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torch.nn.parameter import Parameter


class SplitGroupConvolution(PassBase):
    """The eIQ Neutron NPU supports only regular and depthwise convolutions. Group convolutions must be decomposed
    into multiple parallel single-group convolutions.
    Replace the nodes in the following pattern. The square brackets indicate the tensor shapes.

                                                                                 │[N, Ic, ...]
                                                                             ┌───▼───┐
                                                                             │ split │
                                                                             └┬─────┬┘
                                                 ┌────────────────┘ ... └────────────────┐
           │[N, Ic, ...]                         │[N, Ic/G, ...]                         │[N, Ic/G, ...]
    ┌──────▼──────┐                       ┌──────▼──────┐                         ┌──────▼──────┐
    │ convolution ◄──W [Oc, Ic/G, ...]    │ convolution ◄──W [Oc/G, Ic/G, ...]    │ convolution ◄──W [Oc/G, Ic/G, ...]
    │   group=G   ◄──B [Oc]   ────────►   │   group=1   ◄──B [Oc/G]      ...      │   group=1   ◄──B [Oc/G]
    └──────┬──────┘   replace with        └──────┬──────┘                         └──────┬──────┘
           ▼[N, Oc, ...]                         │[N, Oc/G, ...]                         │[N, Oc/G, ...]
                                                 └────────────────┐ ... ┌────────────────┘
                                                                             ┌▼─────▼┐
                                                                             │  cat  │
                                                                             └───┬───┘
                                                                                 ▼[N, Oc, ...]
    """

    module: GraphModule

    def _get_tensor_constant_from_node(self, node) -> Parameter | None:
        """Get the static data from a given node. If it doesn't have any data, return `None`."""
        if node is None or node.op != "get_attr":
            return None

        target_atoms = node.target.split(".")
        attr_itr = self.module
        for atom in target_atoms:
            if not hasattr(attr_itr, atom):
                return None
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    def _create_and_insert_get_item_node(self, input_node: Node, idx: int) -> Node:
        """Create a `GetItem` node which extracts the output of `input_node` at index `idx`.
        The `GetItem` is also added to the graph right after the `input_node`.
        """
        with self.module.graph.inserting_after(input_node):
            get_item_node = self.module.graph.create_node(
                "call_function",
                operator.getitem,
                (input_node, idx),
                {},
            )

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        get_item_node.meta["source_fn_stack"] = [
            (get_item_node.name, input_node.meta["source_fn_stack"])
        ]
        get_item_node.meta["val"] = input_node.meta["val"][idx]

        return get_item_node

    def _create_split_node(self, *split_args) -> Node:
        split_target = torch.ops.aten.split.default
        split_node = self.module.graph.call_function(split_target, split_args)

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        split_node.meta["source_fn_stack"] = [(split_node.name, torch.split)]

        # Compute the output shapes for the `split`, and assign the `val` meta.
        x_val = split_args[0].meta["val"]
        with FakeTensorMode() as mode:
            fake_input = FakeTensor.from_tensor(
                torch.empty(x_val.shape, dtype=x_val.dtype), mode
            )
            output_shapes = [t.shape for t in split_target(fake_input, *split_args[1:])]
            split_node.meta["val"] = tuple(
                [
                    FakeTensor.from_tensor(torch.empty(shape, dtype=x_val.dtype), mode)
                    for shape in output_shapes
                ]
            )

        return split_node

    def _create_convolution_node(self, conv_target, args: tuple) -> Node:
        convolution_node = self.module.graph.call_function(conv_target, args)

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        convolution_node.meta["source_fn_stack"] = [
            (convolution_node.name, torch.convolution)
        ]

        # Compute the output shapes for the `convolution`, and assign the `val` meta.
        with FakeTensorMode() as mode:
            input_shapes = [
                input_.meta["val"].shape if hasattr(input_, "meta") else input_.shape
                for input_ in args[:3]
            ]
            input_dtypes = [
                input_.meta["val"].dtype if hasattr(input_, "meta") else input_.dtype
                for input_ in args[:3]
            ]
            fake_inputs = [
                FakeTensor.from_tensor(torch.empty(shape, dtype=dtype), mode)
                for shape, dtype in zip(input_shapes, input_dtypes)
            ]
            output = conv_target(*fake_inputs, *args[3:])
            convolution_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(output.shape, dtype=output.dtype), mode
            )

        return convolution_node

    def _create_concat_node(self, *cat_args) -> Node:
        cat_target = torch.ops.aten.cat.default
        concat_node = self.module.graph.call_function(cat_target, cat_args)

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        concat_node.meta["source_fn_stack"] = [(concat_node.name, torch.cat)]

        # Compute the output shape for the `concat`, and assign the `val` meta.
        with FakeTensorMode() as mode:
            fake_inputs = [
                FakeTensor.from_tensor(
                    torch.empty(
                        input_.meta["val"].shape, dtype=input_.meta["val"].dtype
                    ),
                    mode,
                )
                for input_ in cat_args[0]
            ]
            output = cat_target(fake_inputs, *cat_args[1:])
            concat_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(output.shape, dtype=output.dtype), mode
            )

        return concat_node

    def _get_topologically_last_node(self, nodes: list[Node]) -> Node:
        """Return the node from `nodes` which appears last in the graph."""
        for node in reversed(self.module.graph.nodes):
            if node in nodes:
                return node

        raise RuntimeError(f"None of the nodes `{nodes}` are in the graph.")

    def _create_parameter_node_for_data(
        self, data: torch.Tensor, name: str, insert_after_node: Node
    ) -> Node:
        """Create a parameter node in the graph, which contains the provided `data`."""
        new_name = get_new_attr_name_with_prefix(name)(self.module)

        # Create the node for the parameter.
        param = torch.nn.Parameter(data, False)
        _assign_attr(param, self.module, str(new_name), _AttrKind.PARAMETER)
        with self.module.graph.inserting_after(insert_after_node):
            static_parameter_node = self.module.graph.get_attr(new_name)

        with FakeTensorMode() as mode:
            static_parameter_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(data.shape, dtype=data.dtype), mode
            )

        return static_parameter_node

    def call(self, module: GraphModule):
        self.module = module

        def _is_conv(node_: Node):
            return node_.op == "call_function" and node_.target in (
                torch.ops.aten.conv1d.default,
                torch.ops.aten.conv2d.default,
            )

        made_changes = False

        for node in self.module.graph.nodes:
            if not _is_conv(conv_node := node):
                continue

            if len(conv_node.args) < 7:
                # The `aten.conv` can have fewer args if the others use default values.
                # So in this case, `groups == 1`.
                continue
            x, w, b, stride, padding, dilation, groups = conv_node.args

            if not group_conv_convertible_into_multiple_convolutions(conv_node, groups):
                continue

            if len(x.meta["val"].shape) not in [3, 4]:
                # Only 1D and 2D convolutions are supported by the Neutron backend. Don't decompose anything else.
                continue

            w_data = self._get_tensor_constant_from_node(w)
            b_data = self._get_tensor_constant_from_node(b)
            if w_data is None or b_data is None:
                continue  # Only the standard case with static weights and bias is supported.

            # Create a `split` node to split the main input.
            # Split across dimension `1` (channels), `groups` slices of size `input_split_size`.
            num_input_channels = x.meta["val"].shape[1]
            input_split_sizes = [num_input_channels // groups] * groups
            with self.module.graph.inserting_before(conv_node):
                split_node = self._create_split_node(x, input_split_sizes, 1)

            # Add `GetItem` nodes to extract the outputs of the `split_node`.
            split_getitem_nodes = [
                self._create_and_insert_get_item_node(split_node, i)
                for i in range(groups)
            ]

            # Split the weights and bias, across dimension `0`, slices of size `weight_split_size`.
            weight_split_size = w.meta["val"].shape[0] // groups
            split_weights_data = torch.split(w_data, weight_split_size, 0)
            split_bias_data = torch.split(b_data, weight_split_size, 0)

            # Turn the weights and biases into parameter nodes containing the data.
            # Use a different name for every parameter. The function internally ensures the name's uniqueness, but
            # relying on it sometimes causes strange failures when `groups > 5` for some weird reason.
            split_weight_nodes = [
                self._create_parameter_node_for_data(
                    weight_data, w.name + f"_{i}_", split_node
                )
                for i, weight_data in enumerate(split_weights_data)
            ]
            split_bias_nodes = [
                self._create_parameter_node_for_data(
                    bias_data, b.name + f"_{i}_", split_node
                )
                for i, bias_data in enumerate(split_bias_data)
            ]

            # Create the `conv` nodes.
            with self.module.graph.inserting_after(
                self._get_topologically_last_node(
                    split_getitem_nodes + split_weight_nodes + split_bias_nodes
                )
            ):
                split_conv_nodes = [
                    self._create_convolution_node(
                        conv_node.target,  # Use the same target as the original convolution (1d/2d).
                        (input_getitem, weight, bias, stride, padding, dilation, 1),
                    )
                    for input_getitem, weight, bias in zip(
                        split_getitem_nodes, split_weight_nodes, split_bias_nodes
                    )
                ]

            # Create the `cat` node.
            with self.module.graph.inserting_after(
                self._get_topologically_last_node(split_conv_nodes)
            ):
                concat_node = self._create_concat_node(
                    split_conv_nodes, 1
                )  # Concatenate along the channels.

            # Replace the uses of the original convolution with the `concat_node`.
            conv_node.replace_all_uses_with(concat_node)
            self.module.graph.erase_node(conv_node)

            made_changes = True

        return PassResult(self.module, made_changes)
```
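The referenced test file is not part of this view; the snippet below sketches the kind of check it plausibly performs: capture a grouped conv, run the pass directly, and assert the resulting graph contains `split`, single-group `conv2d` nodes, and `cat`. The capture flow and assertions are assumptions, not the actual test code.

```python
import torch

from executorch.backends.nxp.aten_passes.split_group_convolution import (
    SplitGroupConvolution,
)

model = torch.nn.Conv2d(4, 8, kernel_size=3, groups=2)  # illustrative model
gm = torch.export.export_for_training(model, (torch.randn(1, 4, 16, 16),)).module()

result = SplitGroupConvolution().call(gm)
assert result.modified

# The grouped conv is gone; the graph now contains a `split`, two `group=1`
# `conv2d` nodes, and a `cat` along the channel dimension.
targets = [
    n.target for n in result.graph_module.graph.nodes if n.op == "call_function"
]
assert torch.ops.aten.split.default in targets
assert torch.ops.aten.cat.default in targets
convs = [
    n
    for n in result.graph_module.graph.nodes
    if n.op == "call_function" and n.target == torch.ops.aten.conv2d.default
]
assert len(convs) == 2 and all(conv.args[6] == 1 for conv in convs)
```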

backends/nxp/backend/ir/converter/node_converters/ops_converters/convolution_converter.py

Lines changed: 7 additions & 13 deletions
```diff
@@ -3,6 +3,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
+
 import numpy as np
 import torch
 
@@ -79,8 +81,9 @@ def _is_supported_on_target(
             return False
         elif conv_utils.group_conv_convertible_into_multiple_convolutions(
             node, groups
-        ):  # Separable conv.
-            # Requires addition of `Split` and `Concatenation` operators, which are not supported on Neutron.
+        ):  # Separable conv. This should never be reached, as the node should have been decomposed into
+            # multiple parallel convolutions by the `SplitGroupConvolution` pre-processing pass.
+            logging.warning("Group convolution was not decomposed.")
             return False
         else:  # Unexpected case (should never happen).
             return False
@@ -324,17 +327,8 @@ def _convert_2d_conv(
         elif conv_utils.group_conv_convertible_into_multiple_convolutions(
             t_op, conv_params.groups
         ):
-            # Note: by default the Group Separable Convolution is rejected by the Neutron Partitioner, see the
-            # ConvolutionConveter._is_supported_in_IR()
-            t_op.builtin_options = conv_2d_options.Conv2D()
-
-            return conv_utils.create_separated_convolutions_based_on_group(
-                t_op,
-                conv_params,
-                self.builder,
-                self._convert_unpadded_2D,
-                conv_utils.conv_op_factory,
-            )
+            # This case should have been rejected in the `is_supported_on_target()` method.
+            raise RuntimeError("Group convolution was not decomposed.")
 
         else:
             # Convert to regular `Conv2D`.
```