|
1 | 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. |
2 | | -# Copyright 2024 Arm Limited and/or its affiliates. |
3 | 2 | # All rights reserved. |
| 3 | +# Copyright 2024-2025 Arm Limited and/or its affiliates. |
4 | 4 | # |
5 | 5 | # This source code is licensed under the BSD-style license found in the |
6 | 6 | # LICENSE file in the root directory of this source tree. |
7 | 7 |
|
8 | | -# pyre-unsafe |
9 | 8 |
|
10 | | - |
11 | | -import torch |
12 | | -from executorch.backends.arm._passes.arm_pass_utils import ( |
13 | | - create_node, |
14 | | - get_param_tensor, |
15 | | - is_param_node, |
16 | | -) |
17 | | -from executorch.exir import ExportedProgram |
18 | 9 | from executorch.exir.dialects._ops import ops as exir_ops |
19 | | -from executorch.exir.pass_base import ExportPass, PassResult |
| 10 | +from executorch.exir.pass_base import ExportPass |
20 | 11 |
|
21 | 12 |
|
22 | 13 | class Conv1dUnsqueezePass(ExportPass): |
23 | 14 | """ |
24 | 15 | This pass is used to change conv1d ops into conv2d since TOSA only |
25 | 16 | supports 2d and 3d convolution. This is done by modifying the graph to do the |
26 | 17 | following: |
27 | | - 1) unsqueeze the convolution's input from 3d to 4d |
| 18 | + 1a) unsqueeze the convolution's input from 3d to 4d |
| 19 | + 1b) unsqueeze the convolution's weight from 3d to 4d |
28 | 20 | 2) perform a conv2d (with a modified version of the original conv1d args) |
29 | 21 | 3) squeeze the output back down to 3d. |
30 | 22 | """ |
31 | 23 |
|
32 | | - def __init__(self, exported_program: ExportedProgram) -> None: |
33 | | - super().__init__() |
34 | | - self.exported_program = exported_program |
35 | | - |
36 | | - def unsqueeze_kernel_weights(self, kernel_node): |
37 | | - """ |
38 | | - Unsqueezes the weights of a conv1d to make it 4 dimensional. |
39 | | -
|
40 | | - Args: |
41 | | - kernel_node: the weights of conv1d node to be unsqueezed |
42 | | - """ |
43 | | - kernel_param_3d = get_param_tensor(self.exported_program, kernel_node) |
44 | | - if kernel_param_3d is None: |
45 | | - raise AssertionError("Expected param tensor for the kernel node") |
46 | | - |
47 | | - kernel_param_4d = torch.nn.Parameter( |
48 | | - data=kernel_param_3d.data.contiguous().unsqueeze(dim=-1), |
49 | | - requires_grad=False, |
| 24 | + def call_operator(self, op, args, kwargs, meta): |
| 25 | + if op != exir_ops.edge.aten.convolution.default: |
| 26 | + return super().call_operator(op, args, kwargs, meta) |
| 27 | + stride = list(args[3]) |
| 28 | + if len(stride) != 1: |
| 29 | + return super().call_operator(op, args, kwargs, meta) |
| 30 | + |
| 31 | + x = args[0] |
| 32 | + x_unsqueezed_shape = list(x.data.shape) + [1] |
| 33 | + x = super().call_operator( |
| 34 | + exir_ops.edge.aten.view_copy.default, (x, x_unsqueezed_shape), {}, meta |
50 | 35 | ) |
51 | 36 |
|
52 | | - if torch._export.utils.is_param(self.exported_program, kernel_node): |
53 | | - parameter_name = self.exported_program.graph_signature.inputs_to_parameters[ |
54 | | - kernel_node.name |
55 | | - ] |
56 | | - self.exported_program.state_dict[parameter_name] = kernel_param_4d |
57 | | - kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1) |
58 | | - elif torch._export.utils.is_buffer(self.exported_program, kernel_node): |
59 | | - buffer_name = self.exported_program.graph_signature.inputs_to_buffers[ |
60 | | - kernel_node.name |
61 | | - ] |
62 | | - self.exported_program.state_dict[buffer_name] = kernel_param_4d |
63 | | - kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1) |
64 | | - elif torch._export.utils.is_lifted_tensor_constant( |
65 | | - self.exported_program, kernel_node |
66 | | - ): |
67 | | - buffer_name = ( |
68 | | - self.exported_program.graph_signature.inputs_to_lifted_tensor_constants[ |
69 | | - kernel_node.name |
70 | | - ] |
71 | | - ) |
72 | | - self.exported_program.constants[buffer_name] = kernel_param_4d |
73 | | - kernel_node.meta["val"] = kernel_node.meta["val"].data.unsqueeze(dim=-1) |
74 | | - else: |
75 | | - setattr( |
76 | | - kernel_node.graph.owning_module, |
77 | | - kernel_node.target, |
78 | | - kernel_param_4d, |
79 | | - ) |
80 | | - |
81 | | - def call(self, graph_module: torch.fx.GraphModule): |
82 | | - graph = graph_module.graph |
83 | | - node_list = list(graph.nodes) |
84 | | - for node in node_list: |
85 | | - if node.op == "call_function": |
86 | | - if node.target == exir_ops.edge.aten.convolution.default: |
87 | | - stride = list(node.args[3]) |
88 | | - if len(stride) != 1: |
89 | | - # skip conv if it is not 1d |
90 | | - continue |
91 | | - |
92 | | - kernel_node = node.args[1] |
93 | | - |
94 | | - if not is_param_node(self.exported_program, kernel_node): |
95 | | - raise AssertionError( |
96 | | - "Expected op for convolution weight node to be a get_attr node or a parameter" |
97 | | - ) |
| 37 | + w_meta = meta.copy() |
| 38 | + w_meta.data["input_qparams"] = {} |
| 39 | + w_meta.data["output_qparams"] = {} |
98 | 40 |
|
99 | | - # Modify graph such that the conv changes from 1d to 2d |
100 | | - self.unsqueeze_kernel_weights(kernel_node) |
101 | | - |
102 | | - # (b) Extend stride, padding, and dilation for extra dim |
103 | | - node.args = ( |
104 | | - node.args[0], |
105 | | - node.args[1], |
106 | | - node.args[2], |
107 | | - node.args[3] + [1], # stride |
108 | | - node.args[4] + [0], # padding |
109 | | - node.args[5] + [1], # dilation |
110 | | - node.args[6], |
111 | | - node.args[7] + [0], |
112 | | - node.args[8], |
113 | | - ) |
114 | | - |
115 | | - # c. Add unsqueeze to input (3d -> 4d) and squeeze to output (4d -> 3d) |
116 | | - # unsqueeze -> conv2d -> squeeze |
117 | | - with graph.inserting_before(node): |
118 | | - input_node = node.args[0] |
119 | | - unsqueeze_before = create_node( |
120 | | - graph, exir_ops.edge.aten.unsqueeze_copy.default |
121 | | - ) |
122 | | - unsqueeze_before.args = ( |
123 | | - input_node, # Input is node's original input |
124 | | - -1, # Last Dimension |
125 | | - ) |
126 | | - node.replace_input_with(input_node, unsqueeze_before) |
| 41 | + w = args[1] |
| 42 | + w_unsqueezed_shape = list(w.data.shape) + [1] |
| 43 | + w = super().call_operator( |
| 44 | + exir_ops.edge.aten.view_copy.default, (w, w_unsqueezed_shape), {}, w_meta |
| 45 | + ) |
127 | 46 |
|
128 | | - with graph.inserting_after(node): |
129 | | - squeeze_after = create_node( |
130 | | - graph, |
131 | | - exir_ops.edge.aten.squeeze_copy.dims, |
132 | | - ) |
133 | | - squeeze_after.args = ( |
134 | | - node, # Input is the conv node |
135 | | - [-1], # Last dimension |
136 | | - ) |
137 | | - original_users = [ |
138 | | - user for user in node.users if user != squeeze_after |
139 | | - ] |
140 | | - for user in original_users: |
141 | | - user.replace_input_with(node, squeeze_after) |
| 47 | + new_args = ( |
| 48 | + x, |
| 49 | + w, |
| 50 | + args[2], |
| 51 | + args[3] + [1], # stride |
| 52 | + args[4] + [0], # padding |
| 53 | + args[5] + [1], # dilation |
| 54 | + args[6], |
| 55 | + args[7] + [0], |
| 56 | + args[8], |
| 57 | + ) |
| 58 | + x = super().call_operator( |
| 59 | + exir_ops.edge.aten.convolution.default, new_args, kwargs, meta |
| 60 | + ) |
142 | 61 |
|
143 | | - graph_module.recompile() |
144 | | - # Since we are overriding "call", we need to call the parent's "call" |
145 | | - # to retrace the graph and regenerate metadata |
146 | | - graph_module = super().call(graph_module).graph_module |
| 62 | + x_squeezed_shape = list(x.data.shape)[:-1] |
| 63 | + x = super().call_operator( |
| 64 | + exir_ops.edge.aten.view_copy.default, (x, x_squeezed_shape), {}, meta |
| 65 | + ) |
147 | 66 |
|
148 | | - return PassResult(graph_module, True) |
| 67 | + return x |
0 commit comments