
Commit 6d25df6

NXP backend: added aten.split support (#16490)
### Summary
Adds support for the aten.split operator.

### Test plan
Tests can be run manually using `pytest -c /dev/null backends/nxp/tests/`.

cc @robert-kalmar @JakeStevens @digantdesai @MartinPavella
1 parent 7793b1d commit 6d25df6

File tree

5 files changed: +555 -3 lines changed
backends/nxp/aten_passes/decompose_split_to_slices_pass.py

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
# Copyright 2026 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, TypeAlias

import torch
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class DecomposeSplitToSlicesPass(PassBase):
    """
    The `split` operator returns multiple tensors by partitioning `x` along `dim`. Each partition can be
    produced by one `slice` operator, so replacing the `split` operator with multiple `slice` operators
    yields the same results.

                      ┌─────────────▼─────────────┐
                      │             x             │
                      └─────────────┬─────────────┘
                                    │
              ┌─────────────────────▼─────────────────────┐
              │     aten.split / aten.split_with_sizes    │
              └─────────────────────┬─────────────────────┘
                                    │
             ┌──────────────────────┼──────────────────────┐
             │                      │                      │
    ┌────────▼────────┐    ┌────────▼────────┐    ┌────────▼────────┐
    │   getitem(0)    │    │   getitem(1)    │ ...│  getitem(N-1)   │
    └────────┬────────┘    └────────┬────────┘    └────────┬────────┘
             │                      │                      │
             ▼                      ▼                      ▼
           out0                   out1                 out(N-1)

                                    │
                              replace with
                                    │
                                    ▼

                      ┌─────────────▼─────────────┐
                      │             x             │
                      └─────────────┬─────────────┘
                                    │
             ┌──────────────────────┼──────────────────────┐
             │                      │                      │
    ┌────────▼────────┐    ┌────────▼────────┐    ┌────────▼────────┐
    │  aten.slice(x,  │    │  aten.slice(x,  │ ...│  (more slices)  │
    │   dim,s0,e0)    │    │   dim,s1,e1)    │    │                 │
    └────────┬────────┘    └────────┬────────┘    └────────┬────────┘
             │                      │                      │
             ▼                      ▼                      ▼
           out0                   out1                 out(N-1)
    """

    graph_module: GraphModule

    @staticmethod
    def _is_split_with_sizes(node: Node) -> bool:
        return (
            node.op == "call_function"
            and node.target == torch.ops.aten.split_with_sizes.default
        )

    @staticmethod
    def _is_regular_split(node: Node) -> bool:
        is_split_tensor = (
            node.op == "call_function" and node.target == torch.ops.aten.split.Tensor
        )

        is_split_default = (
            node.op == "call_function" and node.target == torch.ops.aten.split.default
        )

        return is_split_tensor or is_split_default

    def _create_slice_node(self, *slice_args) -> Node:
        slice_target = torch.ops.aten.slice.Tensor
        slice_node = self.graph_module.graph.call_function(slice_target, slice_args)

        slice_node.meta["source_fn_stack"] = [
            (slice_node.name, torch.ops.aten.slice.Tensor)
        ]

        with FakeTensorMode() as mode:
            input_ = slice_args[0].meta["val"]

            fake_input = FakeTensor.from_tensor(
                torch.empty(input_.shape, dtype=input_.dtype), mode
            )
            output = slice_target(fake_input, *slice_args[1:])
            slice_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(output.shape, dtype=output.dtype), mode
            )

        return slice_node

    SlicesArgs: TypeAlias = tuple[list[int], list[int], int]

    def _get_slices_args(self, split_node: Node) -> SlicesArgs:
        split_nodes_chunks = split_node.meta["val"]
        dim = 0 if len(split_node.args) < 3 else split_node.args[2]

        # Sometimes the chunks are stored in a tuple.
        if isinstance(split_nodes_chunks, tuple):
            split_nodes_chunks = list(split_nodes_chunks)

        if not isinstance(split_nodes_chunks, list):
            raise RuntimeError("Faulty split chunks")

        # Compute the `start` and `end` parameters for the slices.
        starts = []
        ends = []

        curr_start = 0
        for s in split_nodes_chunks:
            starts.append(curr_start)
            ends.append(curr_start + s.shape[dim])
            curr_start += s.shape[dim]

        return starts, ends, dim

    def _replace_split_with_slices(self, input_node, split_node, starts, ends, dim):
        # Replace the `getitem` nodes after the split with slices.
        getitem_nodes = list(split_node.users.keys())
        slice_nodes = []
        for i in range(len(starts)):
            slice_arguments = (input_node, dim, starts[i], ends[i])
            with self.graph_module.graph.inserting_after(split_node):
                slice_node = self._create_slice_node(*slice_arguments)
                slice_nodes.append(slice_node)

            getitem_node = getitem_nodes[i]
            getitem_node.replace_all_uses_with(slice_node)

            self.graph_module.graph.erase_node(getitem_node)

        # Wire the split node's remaining uses directly to the input node.
        split_node.replace_all_uses_with(input_node)
        self.graph_module.graph.erase_node(split_node)

    def call(self, graph_module: GraphModule) -> Optional[PassResult]:
        self.graph_module = graph_module
        made_changes = False

        if not any(map(self._is_regular_split, graph_module.graph.nodes)) and not any(
            map(self._is_split_with_sizes, graph_module.graph.nodes)
        ):
            return PassResult(graph_module, made_changes)

        for node in graph_module.graph.nodes:
            # Skip nodes that are not splits.
            is_split_with_sizes = self._is_split_with_sizes(node)
            is_regular_split = self._is_regular_split(node)

            if not is_split_with_sizes and not is_regular_split:
                continue

            # Get the split arguments.
            split_node = node
            input_node = split_node.all_input_nodes[0]
            split_nodes_chunks = split_node.meta["val"]

            # Check if the split is even necessary; if not, remove it.
            if len(split_nodes_chunks) == 1:
                getitem_node = list(split_node.users)[0]
                getitem_node.replace_all_uses_with(input_node)

                self.graph_module.graph.erase_node(getitem_node)
                self.graph_module.graph.erase_node(split_node)

                made_changes = True
                continue

            # Get the arguments for the new slices.
            starts, ends, dim = self._get_slices_args(split_node)

            # Replace the split with slices and restructure the graph.
            self._replace_split_with_slices(input_node, split_node, starts, ends, dim)
            made_changes = True

        self.graph_module.recompile()
        self.graph_module.graph.eliminate_dead_code()

        return PassResult(self.graph_module, made_changes)
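
The equivalence the pass relies on is easy to check in eager PyTorch. Below is a minimal sketch (not part of the commit) that accumulates the slice start/end offsets the same way `_get_slices_args` does and compares against `torch.split`:

import torch

x = torch.randn(2, 8)
sections, dim = [1, 3, 4], 1

# Accumulate slice boundaries cumulatively, mirroring `_get_slices_args`.
starts, ends, curr = [], [], 0
for size in sections:
    starts.append(curr)
    ends.append(curr + size)
    curr += size

chunks = torch.split(x, sections, dim)
# `torch.ops.aten.slice.Tensor` is the exact target the pass inserts.
slices = [torch.ops.aten.slice.Tensor(x, dim, s, e) for s, e in zip(starts, ends)]
assert all(torch.equal(c, s) for c, s in zip(chunks, slices))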

backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2026 NXP
+# Copyright 2025-2026 NXP
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -10,6 +10,9 @@
 from executorch.backends.nxp.aten_passes.convert_unsqueeze_to_view import (
     ConvertUnsqueezeToViewPass,
 )
+from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
+    DecomposeSplitToSlicesPass,
+)
 from executorch.backends.nxp.aten_passes.fuse_batch_norm_with_conv_pass import (
     FuseBatchNormWithConvPass,
 )
@@ -45,6 +48,7 @@ def __init__(
         self, neutron_target_spec: NeutronTargetSpec, passes: list[PassType] = None
     ):
         passes: list[PassType] = passes or [
+            DecomposeSplitToSlicesPass(),
             FuseBatchNormWithConvPass(),
             FuseBatchNormWithLinearPass(),
             SplitGroupConvolution(),
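
For illustration, here is a hedged sketch (not from the commit) of running the new pass standalone on an exported GraphModule. It assumes `torch.export` populates `node.meta["val"]`, and which split variant ends up in the graph (`aten.split.Tensor` vs. `aten.split_with_sizes.default`) may vary by torch version; the pass matches both:

import torch
from executorch.backends.nxp.aten_passes.decompose_split_to_slices_pass import (
    DecomposeSplitToSlicesPass,
)

class SplitModule(torch.nn.Module):
    def forward(self, x):
        return torch.split(x, 2, dim=1)

gm = torch.export.export(SplitModule(), (torch.randn(1, 4),)).module()
result = DecomposeSplitToSlicesPass()(gm)  # PassBase.__call__ dispatches to call()
assert result.modified
# The graph should now contain aten.slice.Tensor nodes instead of split/getitem.
print([node.target for node in result.graph_module.graph.nodes])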

backends/nxp/tests/models.py

Lines changed: 32 additions & 0 deletions
@@ -633,6 +633,38 @@ def forward(self, x):
         return self.activation(x)
 
 
+class GRUModel(nn.Module):
+    def __init__(self, num_layers=1):
+        super().__init__()
+        self.gru = torch.nn.GRU(8, 8, num_layers=num_layers)
+
+    def forward(self, input_):
+        # `input_` has shape [sequence_length, batch_size, input_size] ([8, 1, 8]).
+        return self.gru(
+            input_, None
+        )  # The initial hidden state is `None`, which results in a `Zeros` node being added.
+
+
+class SplitWithSize(torch.nn.Module):
+    def __init__(self, split_size, dim):
+        super().__init__()
+        self.split_size = split_size
+        self.dim = dim
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        return torch.split(x, self.split_size, self.dim)
+
+
+class SplitWithSections(torch.nn.Module):
+    def __init__(self, sections, dim):
+        super().__init__()
+        self.sections = sections
+        self.dim = dim
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
+        return torch.split(x, self.sections, self.dim)
+
+
 class MiniConvNetWithRegressionHead(torch.nn.Module):
     def __init__(self):
         super().__init__()
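
For reference, the two new test models exercise the two calling conventions of `torch.split` (standard PyTorch semantics, shown here as a usage sketch): an integer `split_size` yields equal chunks plus a possibly smaller remainder, while a list of `sections` gives explicit chunk sizes along `dim`:

import torch

chunks = SplitWithSize(split_size=3, dim=1)(torch.randn(2, 8))
print([tuple(c.shape) for c in chunks])  # [(2, 3), (2, 3), (2, 2)]

chunks = SplitWithSections(sections=[1, 3, 4], dim=1)(torch.randn(2, 8))
print([tuple(c.shape) for c in chunks])  # [(2, 1), (2, 3), (2, 4)]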
