| 
1 |  | -# Copyright 2024 Arm Limited and/or its affiliates.  | 
 | 1 | +# Copyright 2024-2025 Arm Limited and/or its affiliates.  | 
2 | 2 | # All rights reserved.  | 
3 | 3 | #  | 
4 | 4 | # This source code is licensed under the BSD-style license found in the  | 
5 | 5 | # LICENSE file in the root directory of this source tree.  | 
6 | 6 | 
 
  | 
7 | 7 | # pyre-unsafe  | 
8 | 8 | 
 
  | 
9 |  | -from typing import cast, Optional  | 
 | 9 | +from typing import cast  | 
10 | 10 | 
 
  | 
11 | 11 | import torch.fx  | 
 | 12 | +from executorch.backends.arm._passes.arm_pass_utils import create_node  | 
12 | 13 | from executorch.exir.dialects._ops import ops as exir_ops  | 
13 | 14 | from executorch.exir.pass_base import ExportPass, PassResult  | 
14 |  | -from torch._ops import OpOverload  | 
15 | 15 | 
 
  | 
16 | 16 | 
 
  | 
17 | 17 | def conv_remainder(input_length, pad, dilation, weight, stride):  | 
18 | 18 |     """  | 
19 |  | -    Returns the size  | 
 | 19 | +    Returns the remainder of input_length; given the padding, dilation, stride,  | 
 | 20 | +    and kernel size.  | 
20 | 21 |     """  | 
21 | 22 |     return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride  | 
22 | 23 | 
 
  | 
23 | 24 | 
 
  | 
24 |  | -def insert_q_dq_pair(  | 
25 |  | -    graph: torch.fx.Graph,  | 
26 |  | -    anchor: torch.fx.Node,  | 
27 |  | -    q_params: tuple,  | 
28 |  | -):  | 
29 |  | -    with graph.inserting_after(anchor):  | 
30 |  | -        q = create_node(  | 
31 |  | -            graph=graph,  | 
32 |  | -            op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,  | 
33 |  | -            args=(),  # We add the argument last  | 
34 |  | -        )  | 
35 |  | -        q.meta = anchor.meta  | 
36 |  | - | 
37 |  | -    with graph.inserting_after(q):  | 
38 |  | -        dq = create_node(  | 
39 |  | -            graph=graph,  | 
40 |  | -            op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,  | 
41 |  | -            args=(q,) + q_params,  | 
42 |  | -        )  | 
43 |  | -        dq.meta = q.meta  | 
44 |  | - | 
45 |  | -    anchor.replace_all_uses_with(dq)  | 
46 |  | -    # We add this last so the replace all uses above does not replace the quantized  | 
47 |  | -    # node's first use  | 
48 |  | -    q.args = (anchor,) + q_params  | 
49 |  | -    return dq  | 
50 |  | - | 
51 |  | - | 
52 |  | -def create_node(  | 
53 |  | -    graph: torch.fx.Graph,  | 
54 |  | -    op_target: OpOverload,  | 
55 |  | -    args: tuple = (),  | 
56 |  | -    kwargs: Optional[dict] = None,  | 
57 |  | -):  | 
58 |  | -    return graph.create_node(  | 
59 |  | -        "call_function",  | 
60 |  | -        op_target,  | 
61 |  | -        args=args,  | 
62 |  | -        kwargs=kwargs or {},  | 
63 |  | -    )  | 
64 |  | - | 
65 |  | - | 
66 | 25 | class SizeAdjustConv2DPass(ExportPass):  | 
67 | 26 |     """  | 
68 |  | -    Adjust the convolution input size to match perfectly with the  | 
69 |  | -    weight size, padding, stride and dilation parameters.  | 
70 |  | -    This is done by inserting a slice op to remove the uneven end of the input.  | 
 | 27 | +    Adjust the convolution input size to match the kernel size, padding, stride,  | 
 | 28 | +    and dilation parameters. Pytorch allows the input and kernel shape to not  | 
 | 29 | +    "match", in which case the remaining rows/columns are truncated. However,  | 
 | 30 | +    matching the size is a requirement in the TOSA specification. In case the  | 
 | 31 | +    input and kernel shape do not match, the following is done to meet the  | 
 | 32 | +    specification:  | 
 | 33 | +
  | 
 | 34 | +      1) The padding is truncated (done in the node visitor)  | 
 | 35 | +      2) (if neccessary) The input is truncated (done in this pass)."  | 
 | 36 | +
  | 
 | 37 | +    A simple example would be a 2x2 kernel (no padding, stride=2) and a 5x5  | 
 | 38 | +    input:  | 
 | 39 | +
  | 
 | 40 | +    ┌───┬───┬───┬───┬───┐    ┌───┬───┬───┬───┬───┐    ┌───┬───┬───┬───┬───┐  | 
 | 41 | +    │ X │ X │   │   │   │    │   │   │ X │ X │   │    │   │   │   │   │ - │  | 
 | 42 | +    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤  | 
 | 43 | +    │ X │ X │   │   │   │    │   │   │ X │ X │   │    │   │   │   │   │ - │  | 
 | 44 | +    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤  | 
 | 45 | +    │   │   │   │   │   │ -> │   │   │   │   │   │ -> │ X │ X │   │   │   │ ->  | 
 | 46 | +    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤  | 
 | 47 | +    │   │   │   │   │   │    │   │   │   │   │   │    │ X │ X │   │   │   │  | 
 | 48 | +    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤  | 
 | 49 | +    │   │   │   │   │   │    │   │   │   │   │   │    │   │   │   │   │   │  | 
 | 50 | +    └───┴───┴───┴───┴───┘    └───┴───┴───┴───┴───┘    └───┴───┴───┴───┴───┘  | 
 | 51 | +         First pass               second pass              third pass  | 
 | 52 | +
  | 
 | 53 | +    ┌───┬───┬───┬───┬───┐    ┌───┬───┬───┬───┬───┐  | 
 | 54 | +    │   │   │   │   │   │    │   │   │   │   │ - │  | 
 | 55 | +    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤  | 
 | 56 | +    │   │   │   │   │   │    │   │   │   │   │ - │  | 
 | 57 | +    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤  | 
 | 58 | +    │   │   │ X │ X │   │ -> │   │   │   │   │ - │  | 
 | 59 | +    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤  | 
 | 60 | +    │   │   │ X │ X │   │    │   │   │   │   │ - │  | 
 | 61 | +    ├───┼───┼───┼───┼───┤    ├───┼───┼───┼───┼───┤  | 
 | 62 | +    │   │   │   │   │   │    │ - │ - │ - │ - │ - │  | 
 | 63 | +    └───┴───┴───┴───┴───┘    └───┴───┴───┴───┴───┘  | 
 | 64 | +         Fourth pass            Unvisited cells  | 
 | 65 | +
  | 
 | 66 | +    Cells that are never visited are marked with `-` and are never considered  | 
 | 67 | +    when the kernel traverses over the input, hence they can be removed.  | 
 | 68 | +
  | 
 | 69 | +    To match the shape of the kernel (and all parameters) with the input, a  | 
 | 70 | +    slice op is inserted to remove the remaining edges (rows and columns) of the  | 
 | 71 | +    input.  | 
71 | 72 |     """  | 
72 | 73 | 
 
  | 
73 | 74 |     conv2d_op = exir_ops.edge.aten.convolution.default  | 
@@ -109,9 +110,7 @@ def call(self, graph_module: torch.fx.GraphModule):  | 
109 | 110 |             with graph_module.graph.inserting_before(node):  | 
110 | 111 |                 last_node = cast(torch.fx.Node, input_node)  | 
111 | 112 |                 for args in slice_args:  | 
112 |  | -                    slice_node = graph.create_node(  | 
113 |  | -                        "call_function", self.slice_op, (last_node,) + args  | 
114 |  | -                    )  | 
 | 113 | +                    slice_node = create_node(graph, self.slice_op, (last_node,) + args)  | 
115 | 114 |                     last_node = slice_node  | 
116 | 115 |                 conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)  | 
117 | 116 |                 modified_graph = True  | 
 | 
0 commit comments