Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,43 @@ def define_node(
input_zp = 0
output_zp = 0

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)


@register_node_visitor
class AvgPool2dVisitor_INT16(AvgPool2dVisitor):
target = "aten.avg_pool2d.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target, [inputs[0], output], ts.DType.INT16, output.tosa_spec
)

accumulator_type = ts.DType.INT32

input_qargs = get_input_qparams(node)
input_zp = input_qargs[0].get_zp_per_tensor()

output_qargs = get_output_qparams(node)
output_zp = output_qargs[0].get_zp_per_tensor()

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)
100 changes: 100 additions & 0 deletions backends/arm/test/ops/test_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
get_symmetric_a16w8_quantization_config,
TOSAQuantizer,
)
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.test_pipeline import (
Expand All @@ -22,6 +26,8 @@
TosaPipelineINT,
VgfPipeline,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.xnnpack.test.tester import Quantize

aten_op = "torch.ops.aten.avg_pool2d.default"
exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"
Expand Down Expand Up @@ -232,3 +238,97 @@ def test_avg_pool2d_u55_INT_not_delegated(reject_module):
u55_subset=True,
)
pipeline.run()


def get_symmetric_a16w8_avg_pool2d_quantizer(per_channel_quantization=False):
tosa_version = conftest.get_option("tosa_version")
tosa_profiles = {
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"),
}

quantizer = TOSAQuantizer(tosa_profiles[tosa_version])
quantizer.set_global(
get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization)
)

return Quantize(
quantizer,
get_symmetric_a16w8_quantization_config(
is_per_channel=per_channel_quantization
),
)


@common.parametrize("test_module", test_modules)
def test_avg_pool2d_16a8w_tosa_INT(test_module):
"""Test avg_pool2d operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
model, input_tensor = test_module()
per_channel_quantization = False

pipeline = TosaPipelineINT[input_t](
model,
input_tensor,
aten_op,
exir_op=[],
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
tosa_extensions=["int16"],
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_avg_pool2d_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone300
def test_avg_pool2d_16a8w_u55_INT16(test_module):
"""Test avg_pool2d operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
model, input_tensor = test_module()
per_channel_quantization = False

pipeline = EthosU55PipelineINT[input_t](
model,
input_tensor,
aten_op,
exir_op,
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_avg_pool2d_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()


@common.parametrize("test_module", test_modules)
@common.XfailIfNoCorstone320
def test_avg_pool2d_16a8w_u85_INT16(test_module):
"""Test avg_pool2d operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
model, input_tensor = test_module()
per_channel_quantization = False

pipeline = EthosU85PipelineINT[input_t](
model,
input_tensor,
aten_op,
exir_op,
per_channel_quantization=per_channel_quantization,
use_to_edge_transform_and_lower=True,
)

pipeline.change_args(
"quantize",
get_symmetric_a16w8_avg_pool2d_quantizer(
per_channel_quantization=per_channel_quantization
),
)
pipeline.run()
Loading