|
1 | | -# Copyright 2024 Arm Limited and/or its affiliates. |
| 1 | +# Copyright 2024-2025 Arm Limited and/or its affiliates. |
2 | 2 | # All rights reserved. |
3 | 3 | # |
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | 7 | # pyre-unsafe |
8 | 8 |
|
9 | | -from typing import cast, Optional |
| 9 | +from typing import cast |
10 | 10 |
|
11 | 11 | import torch.fx |
| 12 | +from executorch.backends.arm._passes.arm_pass_utils import create_node |
12 | 13 | from executorch.exir.dialects._ops import ops as exir_ops |
13 | 14 | from executorch.exir.pass_base import ExportPass, PassResult |
14 | | -from torch._ops import OpOverload |
15 | 15 |
|
16 | 16 |
|
17 | 17 | def conv_remainder(input_length, pad, dilation, weight, stride): |
18 | 18 | """ |
19 | | - Returns the size |
| 19 | + Returns the remainder of input_length; given the padding, dilation, stride, |
| 20 | + and kernel size. |
20 | 21 | """ |
21 | 22 | return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride |
22 | 23 |
|
23 | 24 |
|
24 | | -def insert_q_dq_pair( |
25 | | - graph: torch.fx.Graph, |
26 | | - anchor: torch.fx.Node, |
27 | | - q_params: tuple, |
28 | | -): |
29 | | - with graph.inserting_after(anchor): |
30 | | - q = create_node( |
31 | | - graph=graph, |
32 | | - op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, |
33 | | - args=(), # We add the argument last |
34 | | - ) |
35 | | - q.meta = anchor.meta |
36 | | - |
37 | | - with graph.inserting_after(q): |
38 | | - dq = create_node( |
39 | | - graph=graph, |
40 | | - op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, |
41 | | - args=(q,) + q_params, |
42 | | - ) |
43 | | - dq.meta = q.meta |
44 | | - |
45 | | - anchor.replace_all_uses_with(dq) |
46 | | - # We add this last so the replace all uses above does not replace the quantized |
47 | | - # node's first use |
48 | | - q.args = (anchor,) + q_params |
49 | | - return dq |
50 | | - |
51 | | - |
52 | | -def create_node( |
53 | | - graph: torch.fx.Graph, |
54 | | - op_target: OpOverload, |
55 | | - args: tuple = (), |
56 | | - kwargs: Optional[dict] = None, |
57 | | -): |
58 | | - return graph.create_node( |
59 | | - "call_function", |
60 | | - op_target, |
61 | | - args=args, |
62 | | - kwargs=kwargs or {}, |
63 | | - ) |
64 | | - |
65 | | - |
66 | 25 | class SizeAdjustConv2DPass(ExportPass): |
67 | 26 | """ |
68 | | - Adjust the convolution input size to match perfectly with the |
69 | | - weight size, padding, stride and dilation parameters. |
70 | | - This is done by inserting a slice op to remove the uneven end of the input. |
| 27 | + Adjust the convolution input size to match the kernel size, padding, stride, |
| 28 | + and dilation parameters. Pytorch allows the input and kernel shape to not |
| 29 | + "match", in which case the remaining rows/columns are truncated. However, |
| 30 | + matching the size is a requirement in the TOSA specification. In case the |
| 31 | + input and kernel shape do not match, the following is done to meet the |
| 32 | + specification: |
| 33 | +
|
| 34 | + 1) The padding is truncated (done in the node visitor) |
| 35 | + 2) (if neccessary) The input is truncated (done in this pass)." |
| 36 | +
|
| 37 | + A simple example would be a 2x2 kernel (no padding, stride=2) and a 5x5 |
| 38 | + input: |
| 39 | +
|
| 40 | + ┌───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┐ |
| 41 | + │ X │ X │ │ │ │ │ │ │ X │ X │ │ │ │ │ │ │ - │ |
| 42 | + ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ |
| 43 | + │ X │ X │ │ │ │ │ │ │ X │ X │ │ │ │ │ │ │ - │ |
| 44 | + ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ |
| 45 | + │ │ │ │ │ │ -> │ │ │ │ │ │ -> │ X │ X │ │ │ │ -> |
| 46 | + ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ |
| 47 | + │ │ │ │ │ │ │ │ │ │ │ │ │ X │ X │ │ │ │ |
| 48 | + ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ |
| 49 | + │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ |
| 50 | + └───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┘ |
| 51 | + First pass second pass third pass |
| 52 | +
|
| 53 | + ┌───┬───┬───┬───┬───┐ ┌───┬───┬───┬───┬───┐ |
| 54 | + │ │ │ │ │ │ │ │ │ │ │ - │ |
| 55 | + ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ |
| 56 | + │ │ │ │ │ │ │ │ │ │ │ - │ |
| 57 | + ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ |
| 58 | + │ │ │ X │ X │ │ -> │ │ │ │ │ - │ |
| 59 | + ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ |
| 60 | + │ │ │ X │ X │ │ │ │ │ │ │ - │ |
| 61 | + ├───┼───┼───┼───┼───┤ ├───┼───┼───┼───┼───┤ |
| 62 | + │ │ │ │ │ │ │ - │ - │ - │ - │ - │ |
| 63 | + └───┴───┴───┴───┴───┘ └───┴───┴───┴───┴───┘ |
| 64 | + Fourth pass Unvisited cells |
| 65 | +
|
| 66 | + Cells that are never visited are marked with `-` and are never considered |
| 67 | + when the kernel traverses over the input, hence they can be removed. |
| 68 | +
|
| 69 | + To match the shape of the kernel (and all parameters) with the input, a |
| 70 | + slice op is inserted to remove the remaining edges (rows and columns) of the |
| 71 | + input. |
71 | 72 | """ |
72 | 73 |
|
73 | 74 | conv2d_op = exir_ops.edge.aten.convolution.default |
@@ -109,9 +110,7 @@ def call(self, graph_module: torch.fx.GraphModule): |
109 | 110 | with graph_module.graph.inserting_before(node): |
110 | 111 | last_node = cast(torch.fx.Node, input_node) |
111 | 112 | for args in slice_args: |
112 | | - slice_node = graph.create_node( |
113 | | - "call_function", self.slice_op, (last_node,) + args |
114 | | - ) |
| 113 | + slice_node = create_node(graph, self.slice_op, (last_node,) + args) |
115 | 114 | last_node = slice_node |
116 | 115 | conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node) |
117 | 116 | modified_graph = True |
|
0 commit comments