
Commit 0cc7e9b

[Int16] Add 16a8w support for avg_pool op
Pull Request resolved: #15585

Adds 16a8w quantized support for avg_pool op.

ghstack-source-id: 320983461
@exported-using-ghexport

Differential Revision: [D86236826](https://our.internmc.facebook.com/intern/diff/D86236826/)
1 parent 98afdeb commit 0cc7e9b
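
For readers who want to see how the pieces fit together, below is a minimal, illustrative sketch of assembling the 16a8w (16-bit activation, 8-bit weight) quantization setup that the new tests exercise. It only reuses names that appear in this diff (TosaSpecification, TOSAQuantizer, get_symmetric_a16w8_quantization_config, Quantize); it is not part of the change itself, and the exact wiring may differ in your setup.

from executorch.backends.arm.quantizer.arm_quantizer import (
    get_symmetric_a16w8_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize

# Target the TOSA 1.0 INT profile with the int16 extension, which is the
# specification the new AvgPool2dVisitor_INT16 registers against.
tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT+int16")

# Apply the symmetric 16-bit-activation / 8-bit-weight config globally.
quantizer = TOSAQuantizer(tosa_spec)
quantizer.set_global(get_symmetric_a16w8_quantization_config(is_per_channel=False))

# Wrap the quantizer in a Quantize stage, as the new tests do via
# pipeline.change_args("quantize", ...).
quantize_stage = Quantize(
    quantizer,
    get_symmetric_a16w8_quantization_config(is_per_channel=False),
)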


2 files changed: +137 −0 lines


backends/arm/operators/op_avg_pool2d.py

Lines changed: 37 additions & 0 deletions
@@ -134,6 +134,43 @@ def define_node(
         input_zp = 0
         output_zp = 0
 
+        self._build_generic_avgpool2d(
+            node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
+        )
+
+
+@register_node_visitor
+class AvgPool2dVisitor_INT16(AvgPool2dVisitor):
+    target = "aten.avg_pool2d.default"
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    ]
+
+    def __init__(self, *args):
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:
+        validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
+        validate_same_dtype(self.target, [inputs[0], output], ts)
+        validate_valid_dtype(
+            self.target, [inputs[0], output], ts.DType.INT16, output.tosa_spec
+        )
+
+        accumulator_type = ts.DType.INT32
+
+        input_qargs = get_input_qparams(node)
+        input_zp = input_qargs[0].get_zp_per_tensor()
+
+        output_qargs = get_output_qparams(node)
+        output_zp = output_qargs[0].get_zp_per_tensor()
+
         self._build_generic_avgpool2d(
             node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
         )

backends/arm/test/ops/test_avg_pool2d.py

Lines changed: 100 additions & 0 deletions
@@ -12,6 +12,10 @@
 
 import torch
 
+from executorch.backends.arm.quantizer.arm_quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    TOSAQuantizer,
+)
 from executorch.backends.arm.test import common
 
 from executorch.backends.arm.test.tester.test_pipeline import (
@@ -22,6 +26,8 @@
     TosaPipelineINT,
     VgfPipeline,
 )
+from executorch.backends.arm.tosa import TosaSpecification
+from executorch.backends.xnnpack.test.tester import Quantize
 
 aten_op = "torch.ops.aten.avg_pool2d.default"
 exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"
@@ -232,3 +238,97 @@ def test_avg_pool2d_u55_INT_not_delegated(reject_module):
         u55_subset=True,
     )
     pipeline.run()
+
+
+def get_symmetric_a16w8_avg_pool2d_quantizer(per_channel_quantization=False):
+    tosa_version = conftest.get_option("tosa_version")
+    tosa_profiles = {
+        "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
+    }
+
+    quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
+    quantizer.set_global(
+        get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
+    )
+
+    return Quantize(
+        quantizer,
+        get_symmetric_a16w8_quantization_config(
+            is_per_channel=per_channel_quantization
+        ),
+    )
+
+
+@common.parametrize("test_module", test_modules)
+def test_avg_pool2d_16a8w_tosa_INT(test_module):
+    """Test avg_pool2d operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
+    model, input_tensor = test_module()
+    per_channel_quantization = False
+
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        input_tensor,
+        aten_op,
+        exir_op=[],
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+        tosa_extensions=["int16"],
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_avg_pool2d_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules)
+@common.XfailIfNoCorstone300
+def test_avg_pool2d_16a8w_u55_INT16(test_module):
+    """Test avg_pool2d operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    model, input_tensor = test_module()
+    per_channel_quantization = False
+
+    pipeline = EthosU55PipelineINT[input_t](
+        model,
+        input_tensor,
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_avg_pool2d_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
+
+
+@common.parametrize("test_module", test_modules)
+@common.XfailIfNoCorstone320
+def test_avg_pool2d_16a8w_u85_INT16(test_module):
+    """Test avg_pool2d operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    model, input_tensor = test_module()
+    per_channel_quantization = False
+
+    pipeline = EthosU85PipelineINT[input_t](
+        model,
+        input_tensor,
+        aten_op,
+        exir_op,
+        per_channel_quantization=per_channel_quantization,
+        use_to_edge_transform_and_lower=True,
+    )
+
+    pipeline.change_args(
+        "quantize",
+        get_symmetric_a16w8_avg_pool2d_quantizer(
+            per_channel_quantization=per_channel_quantization
+        ),
+    )
+    pipeline.run()
