|
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | 6 | # pyre-unsafe |
7 | | -from typing import List |
| 7 | +from typing import Any, List |
8 | 8 |
|
9 | 9 | import torch |
10 | 10 |
|
11 | | -import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore |
12 | | - |
13 | 11 | from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import ( |
14 | 12 | get_input_qparams, |
15 | 13 | get_output_qparams, |
|
19 | 17 | register_node_visitor, |
20 | 18 | ) |
21 | 19 | from executorch.backends.arm.tosa_mapping import TosaArg |
| 20 | +from executorch.backends.arm.tosa_specification import TosaSpecification |
22 | 21 |
|
23 | 22 |
|
24 | 23 | @register_node_visitor |
25 | | -class MaxPool2dVisitor(NodeVisitor): |
| 24 | +class MaxPool2dVisitor_0_80(NodeVisitor): |
26 | 25 | target = "aten.max_pool2d.default" |
27 | 26 |
|
| 27 | + tosa_specs = [ |
| 28 | + TosaSpecification.create_from_string("TOSA-0.80+BI"), |
| 29 | + TosaSpecification.create_from_string("TOSA-0.80+MI"), |
| 30 | + ] |
| 31 | + |
28 | 32 | def __init__(self, *args): |
29 | 33 | super().__init__(*args) |
30 | 34 |
|
31 | 35 | def define_node( |
32 | 36 | self, |
33 | 37 | node: torch.fx.Node, |
34 | | - tosa_graph: ts.TosaSerializer, |
| 38 | + tosa_graph: Any, |
35 | 39 | inputs: List[TosaArg], |
36 | 40 | output: TosaArg, |
37 | 41 | ) -> None: |
| 42 | + import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore |
38 | 43 |
|
39 | 44 | input_tensor = inputs[0] |
40 | 45 | kernel_size = inputs[1].special |
@@ -80,3 +85,53 @@ def define_node( |
80 | 85 | [output.name], |
81 | 86 | attr, |
82 | 87 | ) |
| 88 | + |
| 89 | + |
| 90 | +@register_node_visitor |
| 91 | +class MaxPool2dVisitor(NodeVisitor): |
| 92 | + target = "aten.max_pool2d.default" |
| 93 | + |
| 94 | + tosa_specs = [ |
| 95 | + TosaSpecification.create_from_string("TOSA-1.0+INT"), |
| 96 | + TosaSpecification.create_from_string("TOSA-1.0+FP"), |
| 97 | + ] |
| 98 | + |
| 99 | + def __init__(self, *args): |
| 100 | + super().__init__(*args) |
| 101 | + |
| 102 | + def define_node( |
| 103 | + self, |
| 104 | + node: torch.fx.Node, |
| 105 | + tosa_graph: Any, |
| 106 | + inputs: List[TosaArg], |
| 107 | + output: TosaArg, |
| 108 | + ) -> None: |
| 109 | + |
| 110 | + import serializer.tosa_serializer as ts # type: ignore |
| 111 | + |
| 112 | + input_tensor = inputs[0] |
| 113 | + kernel_size = inputs[1].special |
| 114 | + stride = inputs[2].special |
| 115 | + |
| 116 | + try: |
| 117 | + pad_size_list = inputs[3].special |
| 118 | + pad_size_list = [ |
| 119 | + pad_size_list[0], |
| 120 | + pad_size_list[0], |
| 121 | + pad_size_list[1], |
| 122 | + pad_size_list[1], |
| 123 | + ] |
| 124 | + except IndexError: |
| 125 | + pad_size_list = [0, 0, 0, 0] |
| 126 | + |
| 127 | + attr = ts.TosaSerializerAttribute() |
| 128 | + attr.MaxPool2dAttribute( |
| 129 | + kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1 |
| 130 | + ) |
| 131 | + |
| 132 | + tosa_graph.addOperator( |
| 133 | + ts.TosaOp.Op().MAX_POOL2D, |
| 134 | + [input_tensor.name], |
| 135 | + [output.name], |
| 136 | + attr, |
| 137 | + ) |
0 commit comments