Skip to content

Commit 804866f

Browse files
Arm backend: Support for avg_pool2d for TOSA 1.0
Add support for AVG_POOL2D TOSA 1.0 Signed-off-by: Oscar Andersson <[email protected]> Change-Id: Ic778d4c0cb0cdafae366d36c995992d743200f9f
1 parent 3a54bdd commit 804866f

File tree

1 file changed

+134
-7
lines changed

1 file changed

+134
-7
lines changed

backends/arm/operators/op_avg_pool2d.py

Lines changed: 134 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,10 @@
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
7-
from typing import List
7+
from typing import Any, List
88

99
import torch
1010

11-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12-
1311
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
1412
get_input_qparams,
1513
get_output_qparams,
@@ -36,14 +34,16 @@ def __init__(self, *args):
3634
def _build_generic_avgpool2d(
3735
self,
3836
node: torch.fx.Node,
39-
tosa_graph: ts.TosaSerializer,
37+
tosa_graph: Any,
4038
inputs: List[TosaArg],
4139
output: TosaArg,
4240
input_zp: int,
4341
output_zp: int,
44-
accumulator_type: ts.DType,
42+
accumulator_type: Any,
4543
) -> None:
4644

45+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
46+
4747
input_tensor = inputs[0]
4848
kernel_size_list = inputs[1].special
4949
stride_size_list = inputs[2].special
@@ -79,10 +79,12 @@ def _build_generic_avgpool2d(
7979
def define_node(
8080
self,
8181
node: torch.fx.Node,
82-
tosa_graph: ts.TosaSerializer,
82+
tosa_graph: Any,
8383
inputs: List[TosaArg],
8484
output: TosaArg,
8585
) -> None:
86+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
87+
8688
input_tensor = inputs[0]
8789
assert input_tensor.dtype == ts.DType.INT8
8890

@@ -110,10 +112,135 @@ class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
110112
def define_node(
111113
self,
112114
node: torch.fx.Node,
113-
tosa_graph: ts.TosaSerializer,
115+
tosa_graph: Any,
114116
inputs: List[TosaArg],
115117
output: TosaArg,
116118
) -> None:
119+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
120+
121+
assert (
122+
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
123+
), "Only FP32 and INT8 supported"
124+
125+
if inputs[0].dtype == ts.DType.INT8:
126+
super().define_node(node, tosa_graph, inputs, output)
127+
128+
if inputs[0].dtype == ts.DType.FP32:
129+
accumulator_type = ts.DType.FP32
130+
# Initilize zero point to zero.
131+
input_zp = 0
132+
output_zp = 0
133+
134+
self._build_generic_avgpool2d(
135+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
136+
)
137+
138+
139+
@register_node_visitor
140+
class AvgPool2dVisitor(NodeVisitor):
141+
target = "aten.avg_pool2d.default"
142+
143+
tosa_specs = [
144+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
145+
]
146+
147+
def __init__(self, *args):
148+
super().__init__(*args)
149+
150+
def _build_generic_avgpool2d(
151+
self,
152+
node: torch.fx.Node,
153+
tosa_graph: Any,
154+
inputs: List[TosaArg],
155+
output: TosaArg,
156+
input_zp: int,
157+
output_zp: int,
158+
accumulator_type: Any,
159+
) -> None:
160+
161+
import serializer.tosa_serializer as ts # type: ignore
162+
163+
input_tensor = inputs[0]
164+
kernel_size_list = inputs[1].special
165+
stride_size_list = inputs[2].special
166+
167+
try:
168+
pad_size_list = inputs[3].special
169+
pad_size_list = [
170+
pad_size_list[0],
171+
pad_size_list[0],
172+
pad_size_list[1],
173+
pad_size_list[1],
174+
]
175+
except IndexError:
176+
pad_size_list = [0, 0, 0, 0]
177+
178+
attr = ts.TosaSerializerAttribute()
179+
attr.AvgPool2dAttribute(
180+
kernel=kernel_size_list,
181+
stride=stride_size_list,
182+
pad=pad_size_list,
183+
acc_type=accumulator_type,
184+
)
185+
input_zp_tensor = tosa_graph.addConst(
186+
shape=[1], dtype=output.dtype, vals=[input_zp]
187+
)
188+
output_zp_tensor = tosa_graph.addConst(
189+
shape=[1], dtype=output.dtype, vals=[output_zp]
190+
)
191+
192+
tosa_graph.addOperator(
193+
ts.TosaOp.Op().AVG_POOL2D,
194+
[input_tensor.name, input_zp_tensor.name, output_zp_tensor.name],
195+
[output.name],
196+
attr,
197+
)
198+
199+
def define_node(
200+
self,
201+
node: torch.fx.Node,
202+
tosa_graph: Any,
203+
inputs: List[TosaArg],
204+
output: TosaArg,
205+
) -> None:
206+
import serializer.tosa_serializer as ts # type: ignore
207+
208+
input_tensor = inputs[0]
209+
assert input_tensor.dtype == ts.DType.INT8
210+
211+
accumulator_type = ts.DType.INT32
212+
213+
input_qargs = get_input_qparams(node)
214+
input_zp = input_qargs[0].zp
215+
216+
output_qargs = get_output_qparams(node)
217+
output_zp = output_qargs[0].zp
218+
219+
self._build_generic_avgpool2d(
220+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
221+
)
222+
223+
224+
@register_node_visitor
225+
class AvgPool2dVisitor_FP(AvgPool2dVisitor):
226+
target = "aten.avg_pool2d.default"
227+
228+
tosa_specs = [
229+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
230+
]
231+
232+
def __init__(self, *args):
233+
super().__init__(*args)
234+
235+
def define_node(
236+
self,
237+
node: torch.fx.Node,
238+
tosa_graph: Any,
239+
inputs: List[TosaArg],
240+
output: TosaArg,
241+
) -> None:
242+
import serializer.tosa_serializer as ts # type: ignore
243+
117244
assert (
118245
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
119246
), "Only FP32 and INT8 supported"

0 commit comments

Comments
 (0)