
Commit 5abad6c

NXP backend: Add preprocessing pass to split multilayer GRU. (#13757)
### Summary

Add a pre-processing pass at the aten dialect level which splits `gru` nodes with `num_layers > 1` into an equivalent sequence of single-layer `gru` nodes.

### Test plan

Unit tests provided in `backends/nxp/tests/test_gru_splitting.py`.
1 parent cac1a71 commit 5abad6c
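For context, the following is a minimal sketch (not part of this commit; the module and shapes are illustrative) of the kind of model this pass targets: a stacked `torch.nn.GRU`, which appears in the aten dialect as a single `aten.gru.input` node with `num_layers > 1`.

```python
import torch


class MultiLayerGRU(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        # Three stacked GRU layers capture as one `aten.gru.input` node with num_layers == 3.
        self.gru = torch.nn.GRU(input_size=16, hidden_size=32, num_layers=3)

    def forward(self, x, h0):
        return self.gru(x, h0)


model = MultiLayerGRU().eval()
# x: [seq_len, batch, input_size]; h0: [num_layers, batch, hidden_size].
example_inputs = (torch.randn(10, 1, 16), torch.randn(3, 1, 32))
```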

File tree

3 files changed: +569 −0 lines changed


backends/nxp/aten_passes/neutron_aten_pass_manager.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -16,6 +16,9 @@
 from executorch.backends.nxp.aten_passes.split_group_convolution import (
     SplitGroupConvolution,
 )
+from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import (
+    SplitGRUBasedOnNumLayers,
+)
 from executorch.exir.pass_manager import PassManager
 from torch import nn
 from torch.fx.passes.infra.pass_base import PassResult
@@ -30,6 +33,7 @@ def __init__(self, passes: list[PassType] = None):
             FuseBatchNormWithConvPass(),
             FuseBatchNormWithLinearPass(),
             SplitGroupConvolution(),
+            SplitGRUBasedOnNumLayers(),
         ]
 
         super().__init__(passes)
```
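For illustration, here is a hedged sketch of how the registered pass could be exercised on its own. The class and import path come from this diff; the standalone driver function and the way `graph_module` is obtained are assumptions (normally the `NeutronAtenPassManager` runs this pass as part of its default pass list, as shown above).

```python
import torch

from executorch.backends.nxp.aten_passes.split_gru_based_on_num_layers import (
    SplitGRUBasedOnNumLayers,
)


def split_multilayer_grus(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hypothetical standalone driver: `graph_module` is assumed to be an
    # aten-dialect graph (e.g. from an export/capture step) that may contain
    # `aten.gru.input` nodes.
    result = SplitGRUBasedOnNumLayers().call(graph_module)
    if result.modified:
        result.graph_module.recompile()  # Regenerate the module's Python code.
    return result.graph_module
```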
backends/nxp/aten_passes/split_gru_based_on_num_layers.py

Lines changed: 307 additions & 0 deletions
```python
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator

import torch
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import GraphModule, Node
from torch.fx.passes.infra.pass_base import PassBase, PassResult


class SplitGRUBasedOnNumLayers(PassBase):
    """Replace an `aten.gru.input` operator with `num_layers > 1` with a subgraph consisting of multiple chained
    `aten.gru.input` operators with `num_layers == 1`, according to the following schematic.

        X H_h
        │ │
        │ ┌───▼───┐
        │ │ Split │
        │ └─┬───┬─┘
        │ ┌───────┘ └───────┐
        │ ┌─────▼──────┐ ┌──────▼──────────────────┐
        │ │ GetItem[0] │ ... │ GetItem[<num_layers>-1] │
        │ └─────┬──────┘ └──────────────┬──────────┘
    X X_h └────┐ │ │
    │ │ ┌─────▼──▼─────┐ │
    ┌──────────▼───▼──────────┐ W11─► GRU ◄─B11 │
    W11─► ◄─B11 W12─► num_layers=1 ◄─B12 │
    W12─► GRU ◄─B12 └───────┬──────┘ │
    ... │ num_layers=<num_layers> │ ... ┌───────┴───────┐ │
    W<num_layers>2─► ◄─B<num_layers>2 ┌─────▼──────┐ ┌──────▼─────┐ │
    └────────────┬────────────┘ │ GetItem[0] │ │ GetItem[1] │ │
    ┌────────┴────────┐ └─────┬──────┘ └──────┬─────┘ │
    ┌─────▼──────┐ ┌──────▼─────┐ replace with │ ┌──────┘ │
    │ GetItem[0] │ │ GetItem[1] │ ─────────► └────────┼─────── ... ────────────┐ │
    └─────┬──────┘ └──────┬─────┘ │ ┌─────▼──▼─────┐
    ▼ ▼ │ W<num_layers>1─► GRU ◄─B<num_layers>1
    Y Y_h │ W<num_layers>2─► num_layers=1 ◄─B<num_layers>2
    │ └──────┬───────┘
    │ ┌───────┴───────┐
    │ ┌─────▼──────┐ ┌──────▼─────┐
    │ │ GetItem[0] │ │ GetItem[1] │
    │ └─────┬──────┘ └──────┬─────┘
    │ ... ┌──────────┼───────────────┘
    ┌▼──────▼┐ │
    │ Concat │ │
    └───┬────┘ │
    ▼ ▼
    Y_h Y

    The `aten.gru.input` has the following schema:
        aten::gru.input(
            Tensor input, Tensor hx, Tensor[] params,
            bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first
        ) -> (Tensor, Tensor)
    """

    module: GraphModule

    def _get_topologically_last_node(self, nodes: list[Node]) -> Node:
        """Return the node from `nodes` which appears last in the graph."""
        for node in reversed(self.module.graph.nodes):
            if node in nodes:
                return node

        raise RuntimeError(f"None of the nodes `{nodes}` are in the graph.")

    def _create_and_insert_get_item_node(self, input_node: Node, idx: int) -> Node:
        """Create a `GetItem` node which extracts the output of `input_node` on index `idx`.
        The `GetItem` is also added to the graph right after the `input_node`.
        """
        with self.module.graph.inserting_after(input_node):
            get_item_node = self.module.graph.create_node(
                "call_function",
                operator.getitem,
                (input_node, idx),
                {},
            )

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        get_item_node.meta["source_fn_stack"] = [
            (get_item_node.name, input_node.meta["source_fn_stack"])
        ]
        get_item_node.meta["val"] = input_node.meta["val"][idx]

        return get_item_node

    def _create_gru_node(self, *gru_args) -> Node:
        """Create an `aten.gru.input` node with the provided arguments. The node will NOT be added to the graph
        automatically.

        :param gru_args: Arguments for the `aten.gru.input` operation.
        :return: The created GRU Node.
        """
        gru_target = torch.ops.aten.gru.input
        gru_node = self.module.graph.call_function(gru_target, gru_args)

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        gru_node.meta["source_fn_stack"] = [(gru_node.name, torch.nn.modules.rnn.GRU)]

        # Compute the shapes of the GRU outputs, and assign the `val` meta.
        x_val, h_val = gru_args[0].meta["val"], gru_args[1].meta["val"]
        with FakeTensorMode() as mode:
            fake_x = FakeTensor.from_tensor(
                torch.empty(x_val.shape, dtype=x_val.dtype), mode
            )
            fake_h = FakeTensor.from_tensor(
                torch.empty(h_val.shape, dtype=h_val.dtype), mode
            )
            fake_weights = [
                FakeTensor.from_tensor(
                    torch.empty(w.meta["val"].shape, dtype=x_val.dtype), mode
                )
                for w in gru_args[2]
            ]
            output_shapes = [
                t.shape for t in gru_target(fake_x, fake_h, fake_weights, *gru_args[3:])
            ]
            gru_node.meta["val"] = tuple(
                [
                    FakeTensor.from_tensor(torch.empty(shape, dtype=h_val.dtype), mode)
                    for shape in output_shapes
                ]
            )

        return gru_node

    def _create_split_node(self, *split_args) -> Node:
        """Create an `aten.split.default` node with the provided arguments. The node will NOT be added to the graph
        automatically.

        :param split_args: Arguments for the `aten.split.default` operation.
        :return: The created Split Node.
        """
        split_target = torch.ops.aten.split.default
        split_node = self.module.graph.call_function(split_target, split_args)

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        split_node.meta["source_fn_stack"] = [(split_node.name, torch.split)]

        # Compute the output shapes for the `split`, and assign the `val` meta.
        x_val = split_args[0].meta["val"]
        with FakeTensorMode() as mode:
            fake_input = FakeTensor.from_tensor(
                torch.empty(x_val.shape, dtype=x_val.dtype), mode
            )
            output_shapes = [t.shape for t in split_target(fake_input, *split_args[1:])]
            split_node.meta["val"] = tuple(
                [
                    FakeTensor.from_tensor(torch.empty(shape, dtype=x_val.dtype), mode)
                    for shape in output_shapes
                ]
            )

        return split_node

    def create_concat_node(self, *cat_args) -> Node:
        """Create an `aten.cat.default` node with the provided arguments. The node will NOT be added to the graph
        automatically.

        :param cat_args: Arguments for the `aten.cat.default` operation.
        :return: The created Cat Node.
        """
        cat_target = torch.ops.aten.cat.default
        cat_node = self.module.graph.call_function(cat_target, cat_args)

        # Assign the `source_fn_stack` and `val` meta fields as they are required for quantization.
        cat_node.meta["source_fn_stack"] = [(cat_node.name, torch.cat)]

        # Compute the output shape for the `concat`, and assign the `val` meta.
        with FakeTensorMode() as mode:
            fake_inputs = [
                FakeTensor.from_tensor(
                    torch.empty(
                        input_.meta["val"].shape, dtype=input_.meta["val"].dtype
                    ),
                    mode,
                )
                for input_ in cat_args[0]
            ]
            output = cat_target(fake_inputs, *cat_args[1:])
            cat_node.meta["val"] = FakeTensor.from_tensor(
                torch.empty(output.shape, dtype=output.dtype), mode
            )

        return cat_node

    def call(self, module: GraphModule) -> PassResult:
        self.module = module
        made_changes = False

        def _is_gru(node_: Node) -> bool:
            return (
                node_.op == "call_function" and node_.target == torch.ops.aten.gru.input
            )

        if not any(map(_is_gru, module.graph.nodes)):
            return PassResult(module, False)  # No GRU nodes in the model.

        for node in module.graph.nodes:
            if not _is_gru(node):
                continue  # Not GRU.

            original_gru_node = node
            if (num_layers := original_gru_node.args[4]) == 1:
                # Basic 1-layer GRU.
                continue

            if (dropout := node.args[5]) != 0.0 or (train := node.args[6]):
                # Conversion for these cases is not supported, so the pre-processing should not be applied.
                continue

            # The `hx` (initial hidden state) has shape:
            #   [D * num_layers, hidden_size] or [D * num_layers, batch_size, hidden_size]
            # where D = 2 if bidirectional else 1.
            # Split the `hx` into `num_layers` different tensors.
            h_x = original_gru_node.args[1]
            bidirectional = original_gru_node.args[7]
            d = 2 if bidirectional else 1
            with module.graph.inserting_before(original_gru_node):
                # Split across the dimension `0`. Slices of size `d`.
                num_slices = h_x.meta["val"].shape[0] // d
                split_node = self._create_split_node(h_x, [d] * num_slices, 0)

            # Add `GetItem` nodes to extract the outputs of the `split_node`.
            h_0_get_item_nodes = [
                self._create_and_insert_get_item_node(split_node, i)
                for i in range(num_layers)
            ]

            # ---- Create new GRU nodes ----

            all_weights = original_gru_node.args[2]
            has_biases = original_gru_node.args[3]
            # The `all_weights` list contains
            #   [w11, w12, b11, b12, w21, w22, b21, b22, ...] if `has_biases` else [w11, w12, w21, w22, ...].
            step = 4 if has_biases else 2
            if bidirectional:
                # Every other set of weights and biases (2 or 4) represents the reverse connections for the layer.
                step *= 2

            gru_nodes = []
            batch_first = original_gru_node.args[-1]

            # The `GetItem` node which extracts the main output (y) of the previous GRU. (Or the main input for the
            # first GRU).
            prev_gru_main_output_get_item = original_gru_node.args[0]
            output_h_get_item_nodes = (
                []
            )  # `GetItem` nodes which extract the output hidden states of the GRU nodes.
            for i in range(num_layers):
                current_gru_weights = tuple(all_weights[step * i : step * (i + 1)])

                # Select the node to insert the new `GRU` after.
                prev_node = (
                    self._get_topologically_last_node(h_0_get_item_nodes)
                    if i == 0
                    else prev_gru_main_output_get_item
                )

                # Create the new `GRU`.
                with module.graph.inserting_after(prev_node):
                    gru = self._create_gru_node(
                        prev_gru_main_output_get_item,
                        h_0_get_item_nodes[i],
                        current_gru_weights,
                        has_biases,
                        1,  # New `num_layers`.
                        dropout,
                        train,
                        bidirectional,
                        batch_first,
                    )
                    gru_nodes.append(gru)

                # Create the `GetItem` nodes to extract the outputs of the `GRU`.
                prev_gru_main_output_get_item = self._create_and_insert_get_item_node(
                    gru_nodes[i], 0
                )
                output_h_get_item_nodes.append(
                    self._create_and_insert_get_item_node(gru_nodes[i], 1)
                )

            # Add a `Concat` to collect all the output hidden states.
            with module.graph.inserting_after(prev_gru_main_output_get_item):
                concat_node = self.create_concat_node(
                    output_h_get_item_nodes, 0  # Concatenate along the dimension `0`.
                )

            # Replace the uses of the original `GRU` outputs with the new corresponding outputs.
            original_y_get_item, original_yh_get_item = list(
                original_gru_node.users.keys()
            )
            original_y_get_item.replace_all_uses_with(prev_gru_main_output_get_item)
            original_yh_get_item.replace_all_uses_with(concat_node)

            # Remove the old nodes.
            module.graph.erase_node(original_y_get_item)
            module.graph.erase_node(original_yh_get_item)
            module.graph.erase_node(original_gru_node)

            made_changes = True

        return PassResult(module, made_changes)
```
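As a sanity check of the rewrite's semantics, here is a hedged numerical sketch (not from this commit) showing the equivalence the pass relies on: a 2-layer GRU matches two chained 1-layer GRUs that are given the per-layer weight sets and the per-layer slices of the initial hidden state, with the output hidden states concatenated back together, mirroring the Split → GRU → Concat subgraph built above.

```python
import torch

torch.manual_seed(0)
multi = torch.nn.GRU(input_size=4, hidden_size=6, num_layers=2)

# Single-layer GRUs reusing the per-layer weights of the 2-layer GRU.
layer0 = torch.nn.GRU(input_size=4, hidden_size=6, num_layers=1)
layer1 = torch.nn.GRU(input_size=6, hidden_size=6, num_layers=1)
with torch.no_grad():
    layer0.weight_ih_l0.copy_(multi.weight_ih_l0)
    layer0.weight_hh_l0.copy_(multi.weight_hh_l0)
    layer0.bias_ih_l0.copy_(multi.bias_ih_l0)
    layer0.bias_hh_l0.copy_(multi.bias_hh_l0)
    layer1.weight_ih_l0.copy_(multi.weight_ih_l1)
    layer1.weight_hh_l0.copy_(multi.weight_hh_l1)
    layer1.bias_ih_l0.copy_(multi.bias_ih_l1)
    layer1.bias_hh_l0.copy_(multi.bias_hh_l1)

x = torch.randn(5, 1, 4)   # [seq_len, batch, input_size]
h0 = torch.randn(2, 1, 6)  # [num_layers, batch, hidden_size]

y_ref, h_ref = multi(x, h0)

# Mirror the rewritten subgraph: split `hx`, chain the GRUs, concat the hidden states.
y0, h_out0 = layer0(x, h0[0:1])   # `GetItem[0]` of the split.
y1, h_out1 = layer1(y0, h0[1:2])  # `GetItem[1]` of the split.
h_new = torch.cat([h_out0, h_out1], dim=0)

assert torch.allclose(y_ref, y1, atol=1e-6)
assert torch.allclose(h_ref, h_new, atol=1e-6)
```

Note that this equivalence only holds with `dropout == 0.0` and `train == False`, which is exactly why the pass skips GRU nodes that do not satisfy those conditions.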
