Skip to content

Commit 7dd981f

Browse files
ArmBackend: Add support for Conv1D
Differential Revision: D64943666 Pull Request resolved: pytorch#6453
1 parent db38bcc commit 7dd981f

File tree

8 files changed

+604
-25
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
AnnotateChannelsLastDimOrder,
1313
)
1414
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
15+
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
1516
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
1617
ConvertExpandCopyToRepeatPass,
1718
)
@@ -69,6 +70,7 @@ def transform_to_backend_pipeline(
6970
self.add_pass(DecomposeDivPass())
7071
self.add_pass(InsertSqueezeAfterSumPass())
7172
self.add_pass(ConvertSplitToSlicePass())
73+
self.add_pass(Conv1dUnsqueezePass(exported_program))
7274
self.add_pass(DecomposeSoftmaxesPass())
7375
for spec in compile_spec:
7476
if spec.key == "permute_memory_format":

backends/arm/_passes/arm_pass_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
12
# Copyright 2024 Arm Limited and/or its affiliates.
23
# All rights reserved.
34
#
@@ -9,11 +10,57 @@
910
import torch
1011
import torch.fx
1112

13+
from executorch.exir import ExportedProgram
1214
from executorch.exir.dialects._ops import ops as exir_ops
15+
16+
from torch._export.utils import (
17+
get_buffer,
18+
get_lifted_tensor_constant,
19+
get_param,
20+
is_buffer,
21+
is_lifted_tensor_constant,
22+
is_param,
23+
)
1324
from torch._ops import OpOverload
1425
from torch._subclasses.fake_tensor import FakeTensor
1526

1627

28+
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node is a get attr node for a tensor of the model
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
33+
34+
35+
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """Return True if ``node`` carries model data: a get_attr node, or a
    parameter / buffer / lifted tensor constant of ``exp_prog``."""
    # Same short-circuit order as a chained `or`: cheap structural check
    # first, then the exported-program lookups.
    if is_get_attr_node(node):
        return True
    if is_param(exp_prog, node):
        return True
    if is_buffer(exp_prog, node):
        return True
    return is_lifted_tensor_constant(exp_prog, node)
42+
43+
44+
def get_param_tensor(
45+
exp_prog: ExportedProgram, node: torch.fx.Node
46+
) -> Optional[torch.Tensor]:
47+
if node is None:
48+
return None
49+
elif is_param(exp_prog, node):
50+
return get_param(exp_prog, node)
51+
elif is_buffer(exp_prog, node):
52+
return get_buffer(exp_prog, node)
53+
elif is_lifted_tensor_constant(exp_prog, node):
54+
return get_lifted_tensor_constant(exp_prog, node)
55+
elif is_get_attr_node(node):
56+
# This is a hack to support both lifted and unlifted graph
57+
try:
58+
return getattr(node.graph.owning_module, node.target)
59+
except AttributeError:
60+
return getattr(exp_prog.graph_module, node.target)
61+
raise RuntimeError(f"unsupported param type, {node.op}.")
62+
63+
1764
def create_node(
1865
graph: torch.fx.Graph,
1966
op_target: OpOverload,
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# Copyright 2024 Arm Limited and/or its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
9+
import torch
10+
from executorch.backends.arm._passes.arm_pass_utils import (
11+
create_node,
12+
get_param_tensor,
13+
insert_q_dq_pair,
14+
is_param_node,
15+
)
16+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op
17+
from executorch.exir import ExportedProgram
18+
from executorch.exir.dialects._ops import ops as exir_ops
19+
from executorch.exir.pass_base import ExportPass, PassResult
20+
21+
22+
class Conv1dUnsqueezePass(ExportPass):
    """
    This pass is used to change conv1d ops into conv2d since TOSA only
    supports 2d and 3d convolution. This is done by modifying the graph to do the
    following:
    1) unsqueeze the convolution's input from 3d to 4d
    2) if the input to unsqueeze is quantized, insert q/dq-pair after unsqueeze
    3) perform a conv2d (with a modified version of the original conv1d args)
    4) squeeze the output back down to 3d.
    5) if all users of squeeze are quantized, insert q/dq-pair before squeeze
    """

    def __init__(self, exported_program: ExportedProgram) -> None:
        super().__init__()
        # Kept so the pass can look up and rewrite the weight tensors that
        # back the graph's weight nodes (state_dict / constants).
        self.exported_program = exported_program

    def unsqueeze_kernel_weights(self, kernel_node) -> None:
        """
        Unsqueezes the weights of a conv1d to make it 4 dimensional.

        Args:
            kernel_node: the weights of conv1d node to be unsqueezed

        Raises:
            AssertionError: if no tensor can be resolved for kernel_node.
        """
        kernel_param_3d = get_param_tensor(self.exported_program, kernel_node)
        if kernel_param_3d is None:
            raise AssertionError("Expected param tensor for the kernel node")

        # Append a trailing singleton dim: (..., K) -> (..., K, 1), matching
        # the dim=-1 unsqueeze applied to the conv input below.
        kernel_param_4d = torch.nn.Parameter(
            data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1),
            requires_grad=False,
        )

        # Write the 4d tensor back to wherever the exported program stores
        # it (state_dict for params/buffers, constants for lifted tensor
        # constants, or the owning module for a plain get_attr), and keep
        # the node's "val" metadata in sync with the new shape.
        if torch._export.utils.is_param(self.exported_program, kernel_node):
            parameter_name = self.exported_program.graph_signature.inputs_to_parameters[
                kernel_node.name
            ]
            self.exported_program.state_dict[parameter_name] = kernel_param_4d
            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
        elif torch._export.utils.is_buffer(self.exported_program, kernel_node):
            buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
                kernel_node.name
            ]
            self.exported_program.state_dict[buffer_name] = kernel_param_4d
            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
        elif torch._export.utils.is_lifted_tensor_constant(
            self.exported_program, kernel_node
        ):
            buffer_name = (
                self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[
                    kernel_node.name
                ]
            )
            self.exported_program.constants[buffer_name] = kernel_param_4d
            kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1)
        else:
            # Unlifted graph: the tensor lives directly on the module.
            setattr(
                kernel_node.graph.owning_module,
                kernel_node.target,
                kernel_param_4d,
            )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Rewrite every 1d convolution in ``graph_module`` as a 2d one.

        Returns a PassResult with the (retraced) module; the "modified"
        flag is always True.
        """
        graph = graph_module.graph
        # Snapshot the node list: the loop inserts new nodes into the graph.
        node_list = list(graph.nodes)
        for node in node_list:
            if node.op == "call_function":
                if node.target == exir_ops.edge.aten.convolution.default:
                    # A 1d conv is identified by a single-element stride list.
                    stride = list(node.args[3])
                    if len(stride) != 1:
                        # skip conv if it is not 1d
                        continue

                    kernel_node = node.args[1]
                    # In a quantized graph the weight is fed through a
                    # dequantize op; unwrap to reach the parameter node.
                    if kernel_node.target == dq_op:
                        kernel_node = kernel_node.args[0]

                    if not is_param_node(self.exported_program, kernel_node):
                        raise AssertionError(
                            "Expected op for convolution weight node to be a get_attr node or a parameter"
                        )

                    # Modify graph such that the conv changes from 1d to 2d
                    self.unsqueeze_kernel_weights(kernel_node)

                    # (b) Extend stride, padding, and dilation for extra dim
                    node.args = (
                        node.args[0],
                        node.args[1],
                        node.args[2],
                        node.args[3] + [1],  # stride
                        node.args[4] + [0],  # padding
                        node.args[5] + [1],  # dilation
                        node.args[6],
                        node.args[7] + [0],  # output_padding
                        node.args[8],
                    )

                    # c. Add unsqueeze to input (3d -> 4d) and squeeze to output (4d -> 3d)
                    # unsqueeze -> conv2d -> squeeze
                    with graph.inserting_before(node):
                        input_node = node.args[0]
                        unsqueeze_before = create_node(
                            graph, exir_ops.edge.aten.unsqueeze_copy.default
                        )
                        unsqueeze_before.args = (
                            input_node,  # Input is node's original input
                            -1,  # Last Dimension
                        )
                        node.replace_input_with(input_node, unsqueeze_before)

                        # If Quantized we must insert unsqueeze --> q --> dq --> node
                        if input_node.target == dq_op:
                            # Reuse the quant params of the input's dq op.
                            q_params = input_node.args[1:]
                            insert_q_dq_pair(graph, unsqueeze_before, q_params)

                    with graph.inserting_after(node):
                        squeeze_after = create_node(
                            graph,
                            exir_ops.edge.aten.squeeze_copy.dims,
                        )
                        squeeze_after.args = (
                            node,  # Input is the conv node
                            [-1],  # Last dimension
                        )
                        # Reroute all original consumers of the conv to the
                        # squeeze (excluding the squeeze itself).
                        original_users = [
                            user for user in node.users if user != squeeze_after
                        ]
                        for user in original_users:
                            user.replace_input_with(node, squeeze_after)

                        # If quantized, insert conv2d --> q --> dq --> squeeze
                        # NOTE(review): all() is True for an empty user list,
                        # in which case original_users[0] below would raise —
                        # presumably a conv output always has users; confirm.
                        if all(
                            original_user.target == q_op
                            for original_user in original_users
                        ):
                            q_params = original_users[0].args[1:]
                            insert_q_dq_pair(graph, node, q_params)

        graph_module.recompile()
        # Since we are overriding "call", we need to call the parent's "call"
        # to retrace the graph and regenerate metadata
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,5 @@ def call(self, graph_module: torch.fx.GraphModule):
7070
output_node.replace_all_uses_with(slice_node)
7171
graph.eliminate_dead_code()
7272
graph_module.recompile()
73+
graph_module = super().call(graph_module).graph_module
7374
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)