From a47154b867087e9172fff3e7437cfbadb23f00a7 Mon Sep 17 00:00:00 2001
From: Andrew Grebenisan
Date: Tue, 30 Sep 2025 11:22:12 -0700
Subject: [PATCH] Adding avgpool2d (#14703)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/14703

Continued support of custom cadence ops

Differential Revision: D83579062
---
 backends/cadence/aot/ops_registrations.py    |   6 +-
 backends/cadence/aot/ref_implementations.py  |  52 ++++++
 .../aot/tests/test_ref_implementations.py    | 173 ++++++++++++++++++
 3 files changed, 228 insertions(+), 3 deletions(-)

diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py
index 33f9c697818..e3009163d62 100644
--- a/backends/cadence/aot/ops_registrations.py
+++ b/backends/cadence/aot/ops_registrations.py
@@ -2244,10 +2244,10 @@ def avg_pool2d_meta(
     kernel_size: Tuple[int],
     stride: Tuple[int],
     padding: Tuple[int],
-    ceil_mode: bool,
-    count_include_pad: Optional[bool] = True,
+    ceil_mode: bool = False,
+    count_include_pad: bool = True,
     divisor_override: Optional[int] = None,
-    in_zero_point: Optional[int] = None,
+    in_zero_point: Optional[torch.Tensor] = None,
     channel_last: bool = False,
 ) -> torch.Tensor:
     # Use torch native meta kernels when operator semantics are similar
diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py
index b45023c2808..312bed89315 100644
--- a/backends/cadence/aot/ref_implementations.py
+++ b/backends/cadence/aot/ref_implementations.py
@@ -978,6 +978,58 @@ def convolution(
     return conv_out
 
 
+@impl(m, "avg_pool2d")
+def avg_pool2d(
+    input_tensor: torch.Tensor,
+    kernel_size: tuple[int, int],
+    stride: tuple[int, int],
+    padding: tuple[int, int],
+    ceil_mode: bool = False,
+    count_include_pad: bool = False,
+    divisor_override: int | None = None,
+    in_zero_point: torch.Tensor | None = None,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    if channel_last:
+        raise NotImplementedError("Channel last is not yet supported for avg_pool2d")
+
+    in_dtype = input_tensor.dtype
+    pad_h, pad_w = padding
+    if in_zero_point is not None:
+        # Avg pool2d does not allow non-0 padding,
+        # so we manually pad the input
+        pad_value = in_zero_point.item()
+        if not count_include_pad:
+            # To simulate this, just pad with 0s
+            pad_value = 0
+
+        input_tensor = torch.nn.functional.pad(
+            input_tensor,
+            (pad_w, pad_w, pad_h, pad_h),
+            mode="constant",
+            value=pad_value,
+        ).float()
+
+        padding = (0, 0)
+
+    out = torch.nn.functional.avg_pool2d(
+        input_tensor,
+        kernel_size,
+        stride,
+        padding,
+        ceil_mode,
+        count_include_pad,
+        divisor_override,
+    )
+
+    if in_zero_point is not None:
+        min_val = torch.iinfo(in_dtype).min
+        max_val = torch.iinfo(in_dtype).max
+        out = torch.clamp(torch.round(out), min_val, max_val)
+
+    return out.to(in_dtype)
+
+
 def quantized_relu_common(
     X: torch.Tensor,
     X_zero_point: torch.Tensor | int,
diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py
index 606be9098d6..32e9b43e68e 100644
--- a/backends/cadence/aot/tests/test_ref_implementations.py
+++ b/backends/cadence/aot/tests/test_ref_implementations.py
@@ -1533,3 +1533,176 @@ def test_convolution(
                 torch.equal(output, expected_output),
                 f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
             )
+
+    @expand(
+        [
+            # Basic non-quantized average pooling
+            (
+                "basic_non_quantized",
+                torch.tensor(
+                    [
+                        [
+                            [
+                                [1.0, 2.0, 3.0, 4.0],
+                                [5.0, 6.0, 7.0, 8.0],
+                                [9.0, 10.0, 11.0, 12.0],
+                                [13.0, 14.0, 15.0, 16.0],
+                            ]
+                        ]
+                    ],
+                    dtype=torch.float32,
+                ),  # input: 1x1x4x4
+                (2, 2),  # kernel_size
+                (2, 2),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                None,  # divisor_override
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[3.5, 5.5], [11.5, 13.5]]]], dtype=torch.float32
+                ),  # expected: average of 2x2 blocks
+            ),
+            # Non-quantized with count_include_pad=True and padding
+            (
+                "non_quantized_count_include_pad",
+                torch.tensor(
+                    [[[[1.0, 2.0], [3.0, 4.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (3, 3),  # kernel_size (larger than input)
+                (1, 1),  # stride
+                (1, 1),  # padding
+                False,  # ceil_mode
+                True,  # count_include_pad=True
+                None,  # divisor_override
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[2.5, 2.5], [2.5, 2.5]]]],
+                    dtype=torch.float32,
+                ),
+            ),
+            # Non-quantized with divisor_override
+            (
+                "non_quantized_divisor_override",
+                torch.tensor(
+                    [[[[2.0, 4.0], [6.0, 8.0]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                2,  # divisor_override (instead of 4)
+                None,  # in_zero_point (non-quantized)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[10.0]]]], dtype=torch.float32
+                ),  # expected: (2+4+6+8)/2 = 10
+            ),
+            # Quantized with non-zero zero_point and padding
+            (
+                "quantized_nonzero_zero_point",
+                torch.tensor(
+                    [[[[130, 132], [134, 136]]]], dtype=torch.uint8
+                ),  # input: 1x1x2x2, values around zero_point=128
+                (3, 3),  # kernel_size
+                (1, 1),  # stride
+                (1, 1),  # padding
+                False,  # ceil_mode
+                True,  # count_include_pad=True
+                None,  # divisor_override
+                128,  # in_zero_point=128 (padded areas will have this value)
+                False,  # channel_last
+                torch.tensor(
+                    [[[[130, 130], [130, 130]]]], dtype=torch.uint8
+                ),  # expected: averages including padded zero_point values
+            ),
+            # Quantized with divisor_override
+            (
+                "quantized_divisor_override",
+                torch.tensor(
+                    [[[[64, 96], [128, 160]]]], dtype=torch.float32
+                ),  # input: 1x1x2x2
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                2,  # divisor_override (instead of 4)
+                None,  # in_zero_point=None
+                False,  # channel_last
+                torch.tensor(
+                    [[[[224]]]], dtype=torch.float32
+                ),  # expected: (64+96+128+160)/2 = 224
+            ),
+            # Large values that need clamping
+            (
+                "quantized_clamping_test",
+                torch.tensor(
+                    [[[[120, 125], [125, 127]]]], dtype=torch.int8
+                ),  # input: 1x1x2x2, large values for int8
+                (2, 2),  # kernel_size
+                (1, 1),  # stride
+                (0, 0),  # padding
+                False,  # ceil_mode
+                False,  # count_include_pad
+                None,  # divisor_override
+                0,  # in_zero_point=0
+                False,  # channel_last
+                torch.tensor(
+                    [[[[124]]]], dtype=torch.int8
+                ),  # expected: (120+125+125+127)/4 = 124.25 -> 124, within int8 range
+            ),
+        ]
+    )
+    def test_avg_pool2d(
+        self,
+        name: str,
+        input_tensor: torch.Tensor,
+        kernel_size: tuple[int, int],
+        stride: tuple[int, int],
+        padding: tuple[int, int],
+        ceil_mode: bool,
+        count_include_pad: bool,
+        divisor_override: int | None,
+        in_zero_point: int | None,
+        channel_last: bool,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.avg_pool2d(
+            input_tensor,
+            kernel_size,
+            stride,
+            padding,
+            ceil_mode,
+            count_include_pad,
+            divisor_override,
+            in_zero_point if in_zero_point is None else torch.tensor([in_zero_point]),
+            channel_last,
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            input_tensor.dtype,
+            f"Output dtype should match input dtype in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        if input_tensor.dtype.is_floating_point:
+            self.assertTrue(
+                torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+            )
+        else:
+            self.assertTrue(
+                torch.equal(output, expected_output),
+                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+            )
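
Usage sketch (illustrative only; it mirrors the "basic_non_quantized" test case above and assumes the cadence op registrations and reference implementations from this patch have been imported, so that torch.ops.cadence.avg_pool2d resolves to the reference kernel):

    import torch

    # 1x1x4x4 float input; with in_zero_point=None the op takes the non-quantized
    # path and reduces to torch.nn.functional.avg_pool2d over 2x2 blocks.
    x = torch.arange(1.0, 17.0).reshape(1, 1, 4, 4)
    out = torch.ops.cadence.avg_pool2d(
        x,
        (2, 2),  # kernel_size
        (2, 2),  # stride
        (0, 0),  # padding
        False,  # ceil_mode
        False,  # count_include_pad
        None,  # divisor_override
        None,  # in_zero_point (None -> no zero-point padding or clamping)
        False,  # channel_last
    )
    # out == tensor([[[[ 3.5,  5.5], [11.5, 13.5]]]])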