Skip to content

Commit f08ae5a

Browse files
committed
[Int16] Add 16a8w support for avg_pool op
Adds 16a8w quantized support for avg_pool op. Differential Revision: [D86236826](https://our.internmc.facebook.com/intern/diff/D86236826/) [ghstack-poisoned]
1 parent 72d8082 commit f08ae5a

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed

backends/arm/operators/op_avg_pool2d.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,43 @@ def define_node(
134134
input_zp = 0
135135
output_zp = 0
136136

137+
self._build_generic_avgpool2d(
138+
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
139+
)
140+
141+
142+
@register_node_visitor
class AvgPool2dVisitor_INT16(AvgPool2dVisitor):
    """Lower ``aten.avg_pool2d.default`` with 16-bit quantized activations.

    Specializes the base INT visitor for the 16a8w scheme: inputs/outputs are
    INT16 (requires the TOSA ``int16`` extension) and accumulation is done in
    INT32. Zero points are taken from the per-tensor quantization parameters
    attached to the FX node rather than assumed to be zero.
    """

    target = "aten.avg_pool2d.default"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    ]

    # NOTE: no __init__ override — the previous pure pass-through
    # (``def __init__(self, *args): super().__init__(*args)``) was redundant;
    # the inherited constructor is used as-is.

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Validate INT16 dtypes, fetch zero points, and emit the avg_pool2d op.

        Args:
            node: FX node carrying the quantization parameters.
            tosa_graph: Graph object the TOSA operator is serialized into.
            inputs: Operator arguments; ``inputs[0]`` is the activation tensor.
            output: Output tensor descriptor (must match input dtype).
        """
        # avg_pool2d accepts between 3 and 7 arguments (kernel/stride/padding
        # plus optional ceil_mode, count_include_pad, divisor_override).
        validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
        validate_same_dtype(self.target, [inputs[0], output], ts)
        validate_valid_dtype(
            self.target, [inputs[0], output], ts.DType.INT16, output.tosa_spec
        )

        # 16-bit activations accumulate into a 32-bit accumulator.
        accumulator_type = ts.DType.INT32

        # Per-tensor zero points from the node's recorded quantization params.
        input_qargs = get_input_qparams(node)
        input_zp = input_qargs[0].get_zp_per_tensor()

        output_qargs = get_output_qparams(node)
        output_zp = output_qargs[0].get_zp_per_tensor()

        self._build_generic_avgpool2d(
            node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
        )

backends/arm/test/ops/test_avg_pool2d.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212

1313
import torch
1414

15+
from executorch.backends.arm.quantizer import arm_quantizer
16+
from executorch.backends.arm.quantizer.arm_quantizer import (
17+
get_symmetric_a16w8_quantization_config,
18+
TOSAQuantizer,
19+
)
1520
from executorch.backends.arm.test import common
1621

1722
from executorch.backends.arm.test.tester.test_pipeline import (
@@ -22,6 +27,8 @@
2227
TosaPipelineINT,
2328
VgfPipeline,
2429
)
30+
from executorch.backends.arm.tosa import TosaSpecification
31+
from executorch.backends.xnnpack.test.tester import Quantize
2532

2633
aten_op = "torch.ops.aten.avg_pool2d.default"
2734
exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"
@@ -232,3 +239,97 @@ def test_avg_pool2d_u55_INT_not_delegated(reject_module):
232239
u55_subset=True,
233240
)
234241
pipeline.run()
242+
243+
244+
def get_symmetric_a16w8_avg_pool2d_quantizer(per_channel_quantization=False):
    """Build a ``Quantize`` stage for 16-bit-activation / 8-bit-weight tests.

    Args:
        per_channel_quantization: Forwarded to the quantization config as
            ``is_per_channel``.

    Returns:
        A ``Quantize`` test stage whose quantizer is globally configured with
        the symmetric a16w8 scheme for the int16 TOSA profile.
    """
    tosa_version = conftest.get_option("tosa_version")
    tosa_profiles = {
        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
    }

    # Build the config once and reuse it for both the quantizer's global
    # setting and the Quantize stage (previously constructed twice).
    quantization_config = get_symmetric_a16w8_quantization_config(
        is_per_channel=per_channel_quantization
    )

    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
    quantizer.set_global(quantization_config)

    return Quantize(quantizer, quantization_config)
261+
262+
263+
@common.parametrize("test_module", test_modules)
def test_avg_pool2d_16a8w_tosa_INT(test_module):
    """Run avg_pool2d through the TOSA INT pipeline with 16A8W quantization
    (16-bit activations, 8-bit weights)."""
    module, example_input = test_module()
    use_per_channel = False

    # Swap in the a16w8 quantize stage before running the pipeline.
    a16w8_quantize_stage = get_symmetric_a16w8_avg_pool2d_quantizer(
        per_channel_quantization=use_per_channel
    )

    pipeline = TosaPipelineINT[input_t](
        module,
        example_input,
        aten_op,
        exir_op=[],
        per_channel_quantization=use_per_channel,
        use_to_edge_transform_and_lower=True,
        tosa_extensions=["int16"],
    )
    pipeline.change_args("quantize", a16w8_quantize_stage)
    pipeline.run()
286+
287+
288+
@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone300
def test_avg_pool2d_16a8w_u55_INT16(test_module):
    """Run avg_pool2d on the Ethos-U55 pipeline with 16A8W quantization
    (16-bit activations, 8-bit weights)."""
    module, example_input = test_module()
    use_per_channel = False

    # Swap in the a16w8 quantize stage before running the pipeline.
    a16w8_quantize_stage = get_symmetric_a16w8_avg_pool2d_quantizer(
        per_channel_quantization=use_per_channel
    )

    pipeline = EthosU55PipelineINT[input_t](
        module,
        example_input,
        aten_op,
        exir_op,
        per_channel_quantization=use_per_channel,
        use_to_edge_transform_and_lower=True,
    )
    pipeline.change_args("quantize", a16w8_quantize_stage)
    pipeline.run()
311+
312+
313+
@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone320
def test_avg_pool2d_16a8w_u85_INT16(test_module):
    """Run avg_pool2d on the Ethos-U85 pipeline with 16A8W quantization
    (16-bit activations, 8-bit weights)."""
    module, example_input = test_module()
    use_per_channel = False

    # Swap in the a16w8 quantize stage before running the pipeline.
    a16w8_quantize_stage = get_symmetric_a16w8_avg_pool2d_quantizer(
        per_channel_quantization=use_per_channel
    )

    pipeline = EthosU85PipelineINT[input_t](
        module,
        example_input,
        aten_op,
        exir_op,
        per_channel_quantization=use_per_channel,
        use_to_edge_transform_and_lower=True,
    )
    pipeline.change_args("quantize", a16w8_quantize_stage)
    pipeline.run()

0 commit comments

Comments
 (0)