Skip to content

Commit 9f3a16d

Browse files
Arm backend: Support for max_pool2d for TOSA 1.0
Add support for MAX_POOL2D TOSA 1.0 Signed-off-by: Oscar Andersson <[email protected]> Change-Id: I3432622ede7b029e78ca8af5c9c71f17f551d4b3
1 parent 804866f commit 9f3a16d

File tree

1 file changed

+60
-5
lines changed

1 file changed

+60
-5
lines changed

backends/arm/operators/op_max_pool2d.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12-
1311
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1412
get_input_qparams,
1513
get_output_qparams,
@@ -19,22 +17,29 @@
1917
register_node_visitor,
2018
)
2119
from executorch.backends.arm.tosa_mapping import TosaArg
20+
from executorch.backends.arm.tosa_specification import TosaSpecification
2221

2322

2423
@register_node_visitor
25-
class MaxPool2dVisitor(NodeVisitor):
24+
class MaxPool2dVisitor_0_80(NodeVisitor):
2625
target = "aten.max_pool2d.default"
2726

27+
tosa_specs = [
28+
TosaSpecification.create_from_string("TOSA-0.80+BI"),
29+
TosaSpecification.create_from_string("TOSA-0.80+MI"),
30+
]
31+
2832
def __init__(self, *args):
2933
super().__init__(*args)
3034

3135
def define_node(
3236
self,
3337
node: torch.fx.Node,
34-
tosa_graph: ts.TosaSerializer,
38+
tosa_graph: Any,
3539
inputs: List[TosaArg],
3640
output: TosaArg,
3741
) -> None:
42+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
3843

3944
input_tensor = inputs[0]
4045
kernel_size = inputs[1].special
@@ -80,3 +85,53 @@ def define_node(
8085
[output.name],
8186
attr,
8287
)
88+
89+
90+
@register_node_visitor
91+
class MaxPool2dVisitor(NodeVisitor):
92+
target = "aten.max_pool2d.default"
93+
94+
tosa_specs = [
95+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
96+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
97+
]
98+
99+
def __init__(self, *args):
100+
super().__init__(*args)
101+
102+
def define_node(
103+
self,
104+
node: torch.fx.Node,
105+
tosa_graph: Any,
106+
inputs: List[TosaArg],
107+
output: TosaArg,
108+
) -> None:
109+
110+
import serializer.tosa_serializer as ts # type: ignore
111+
112+
input_tensor = inputs[0]
113+
kernel_size = inputs[1].special
114+
stride = inputs[2].special
115+
116+
try:
117+
pad_size_list = inputs[3].special
118+
pad_size_list = [
119+
pad_size_list[0],
120+
pad_size_list[0],
121+
pad_size_list[1],
122+
pad_size_list[1],
123+
]
124+
except IndexError:
125+
pad_size_list = [0, 0, 0, 0]
126+
127+
attr = ts.TosaSerializerAttribute()
128+
attr.MaxPool2dAttribute(
129+
kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1
130+
)
131+
132+
tosa_graph.addOperator(
133+
ts.TosaOp.Op().MAX_POOL2D,
134+
[input_tensor.name],
135+
[output.name],
136+
attr,
137+
)

0 commit comments

Comments
 (0)