diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 29d5c3bf635..04a97bb7a84 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -134,6 +134,43 @@ def define_node( input_zp = 0 output_zp = 0 + self._build_generic_avgpool2d( + node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type + ) + + +@register_node_visitor +class AvgPool2dVisitor_INT16(AvgPool2dVisitor): + target = "aten.avg_pool2d.default" + + tosa_specs = [ + TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + ] + + def __init__(self, *args): + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) + validate_same_dtype(self.target, [inputs[0], output], ts) + validate_valid_dtype( + self.target, [inputs[0], output], ts.DType.INT16, output.tosa_spec + ) + + accumulator_type = ts.DType.INT32 + + input_qargs = get_input_qparams(node) + input_zp = input_qargs[0].get_zp_per_tensor() + + output_qargs = get_output_qparams(node) + output_zp = output_qargs[0].get_zp_per_tensor() + self._build_generic_avgpool2d( node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type ) diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py index 8310d1e40a4..98edf3acd4a 100644 --- a/backends/arm/test/ops/test_avg_pool2d.py +++ b/backends/arm/test/ops/test_avg_pool2d.py @@ -12,6 +12,10 @@ import torch +from executorch.backends.arm.quantizer.arm_quantizer import ( + get_symmetric_a16w8_quantization_config, + TOSAQuantizer, +) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( @@ -22,6 +26,8 @@ TosaPipelineINT, VgfPipeline, ) +from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.xnnpack.test.tester import Quantize aten_op = "torch.ops.aten.avg_pool2d.default" exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default" @@ -232,3 +238,97 @@ def test_avg_pool2d_u55_INT_not_delegated(reject_module): u55_subset=True, ) pipeline.run() + + +def get_symmetric_a16w8_avg_pool2d_quantizer(per_channel_quantization=False): + tosa_version = conftest.get_option("tosa_version") + tosa_profiles = { + "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT+int16"), + } + + quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) + quantizer.set_global( + get_symmetric_a16w8_quantization_config(is_per_channel=per_channel_quantization) + ) + + return Quantize( + quantizer, + get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ), + ) + + +@common.parametrize("test_module", test_modules) +def test_avg_pool2d_16a8w_tosa_INT(test_module): + """Test avg_pool2d operation with 16A8W quantization (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + per_channel_quantization = False + + pipeline = TosaPipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op=[], + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + tosa_extensions=["int16"], + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_avg_pool2d_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone300 +def test_avg_pool2d_16a8w_u55_INT16(test_module): + """Test avg_pool2d operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + per_channel_quantization = False + + pipeline = EthosU55PipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op, + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_avg_pool2d_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run() + + +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone320 +def test_avg_pool2d_16a8w_u85_INT16(test_module): + """Test avg_pool2d operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op, + per_channel_quantization=per_channel_quantization, + use_to_edge_transform_and_lower=True, + ) + + pipeline.change_args( + "quantize", + get_symmetric_a16w8_avg_pool2d_quantizer( + per_channel_quantization=per_channel_quantization + ), + ) + pipeline.run()