diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 406f47c66dc..28721c86d0f 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -20,6 +20,7 @@
 from .convert_split_to_slice import ConvertSplitToSlicePass  # noqa
 from .convert_squeezes_to_view import ConvertSqueezesToViewPass  # noqa
 from .convert_to_clamp import ConvertToClampPass  # noqa
+from .decompose_avg_pool2d import DecomposeAvgPool2d  # noqa
 from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass  # noqa
 from .decompose_div_pass import DecomposeDivPass  # noqa
 from .decompose_embedding_pass import DecomposeEmbeddingPass  # noqa  # noqa
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index d386da3ed72..fcfb62bb803 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -23,6 +23,7 @@
     ConvertSplitToSlicePass,
     ConvertSqueezesToViewPass,
     ConvertToClampPass,
+    DecomposeAvgPool2d,
     DecomposeCosineSimilarityPass,
     DecomposeDivPass,
     DecomposeEmbeddingPass,
@@ -63,7 +64,6 @@
     UnsqueezeBeforeRepeatPass,
     UnsqueezeScalarPlaceholdersPass,
 )
-
 from executorch.backends.arm.tosa_specification import (
     TosaLoweringContext,
     TosaSpecification,
@@ -115,6 +115,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         if self.tosa_spec.is_U55_subset:
             self.add_pass(BroadcastArgsPass())
         self.add_pass(DecomposeLinearPass())
+        self.add_pass(DecomposeAvgPool2d())
         self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
@@ -172,6 +173,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(RetraceFoldedDtypesPass())
         self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(DecomposeAvgPool2d())
         self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
@@ -232,6 +234,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeLinearVectorNormPass())
         self.add_pass(DecomposeSqrtPass())
         self.add_pass(DecomposeSiluPass())
+        self.add_pass(DecomposeAvgPool2d())
 
         if self.tosa_spec.is_U55_subset:
             # Numerically stable softmax uses amax which is not supported on Ethos-U55
diff --git a/backends/arm/_passes/decompose_avg_pool2d.py b/backends/arm/_passes/decompose_avg_pool2d.py
new file mode 100644
index 00000000000..0eb3ce34ecd
--- /dev/null
+++ b/backends/arm/_passes/decompose_avg_pool2d.py
@@ -0,0 +1,123 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+from executorch.backends.arm.operators.operator_validation_utils import (
+    adjust_pooling_pad_if_needed,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass
+
+edge_avgpool_ops = (exir_ops.edge.aten.avg_pool2d.default,)
+aten_avgpool_ops = (torch.ops.aten.avg_pool2d.default,)
+
+
+def get_decomposition(op) -> tuple:
+    if op in edge_avgpool_ops:
+        return (
+            exir_ops.edge.aten.full.default,
+            exir_ops.edge.aten.cat.default,
+            exir_ops.edge.aten.avg_pool2d.default,
+            exir_ops.edge.aten.mul.Tensor,
+        )
+    if op in aten_avgpool_ops:
+        return (
+            torch.ops.aten.full.default,
+            torch.ops.aten.cat.default,
+            torch.ops.aten.avg_pool2d.default,
+            torch.ops.aten.mul.Tensor,
+        )
+    raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}")
+
+
+class DecomposeAvgPool2d(ExportPass):
+    """Decomposes avg_pool2d into explicit zero-padding (full + cat), an unpadded
+    avg_pool2d, and an optional multiply that realizes divisor_override.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        if op not in (edge_avgpool_ops + aten_avgpool_ops):
+            return super().call_operator(op, args, kwargs, meta)
+
+        full_op, cat_op, avgpool_op, mul_op = get_decomposition(op)
+
+        x = args[0]
+        kernel_h, kernel_w = args[1]
+        kernel_size = kernel_h * kernel_w
+        stride_h, stride_w = args[2]
+        pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0)
+        ceil_mode = args[4] if len(args) > 4 else False
+        count_include_pad = args[5] if len(args) > 5 else True
+        divisor_override = args[6] if len(args) > 6 else None
+
+        n, c, h, w = x.data.shape
+        post_pad_w, post_pad_h = (0, 0)
+
+        # count_include_pad == False means a different divisor is used for edge
+        # elements. When divisor_override is set it overrides the divisor anyway,
+        # and a constant divisor is easier to replace, so set count_include_pad = True.
+        if divisor_override is not None:
+            count_include_pad = True
+
+        # Add width padding manually if count_include_pad
+        if count_include_pad and pad_w > 0:
+            pre_pad_shape = [n, c, h, pad_w]
+            pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta)
+
+            if ceil_mode and divisor_override is None:
+                post_pad_w = pad_w
+            else:
+                post_pad_w = adjust_pooling_pad_if_needed(
+                    w, kernel_w, stride_w, pad_w, ceil_mode
+                )
+
+            if post_pad_w > 0:
+                post_pad_shape = [n, c, h, post_pad_w]
+                post_pad = super().call_operator(
+                    full_op, (post_pad_shape, 0.0), kwargs, meta
+                )
+                cat_nodes = [pre_pad, x, post_pad]
+            else:
+                cat_nodes = [pre_pad, x]
+
+            x = super().call_operator(cat_op, (cat_nodes, 3), kwargs, meta)
+            new_pad_w = 0
+
+        # Add height padding manually if count_include_pad
+        if count_include_pad and pad_h > 0:
+            pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w]
+            pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta)
+
+            if ceil_mode and divisor_override is None:
+                post_pad_h = pad_h
+            else:
+                post_pad_h = adjust_pooling_pad_if_needed(
+                    h, kernel_h, stride_h, pad_h, ceil_mode
+                )
+
+            if post_pad_h > 0:
+                post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w]
+                post_pad = super().call_operator(
+                    full_op, (post_pad_shape, 0.0), kwargs, meta
+                )
+                cat_nodes = [pre_pad, x, post_pad]
+            else:
+                cat_nodes = [pre_pad, x]
+
+            x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta)
+            new_pad_h = 0
+
+        avgpool_args = (x, args[1], args[2], [new_pad_h, new_pad_w], ceil_mode, False)
+        x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta)
+
+        # Multiply by factor (kernel_size / divisor_override) if divisor_override
+        if divisor_override is not None and divisor_override != kernel_size:
+            override_multiplier = super().call_operator(
+                full_op, ([1, 1, 1, 1], kernel_size / divisor_override), kwargs, meta
+            )
+            x = super().call_operator(mul_op, (x, override_multiplier), kwargs, meta)
+
+        return x
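Sanity check (standalone sketch, not applied by this diff): the pass above relies on two identities that can be verified in plain PyTorch with the public `torch.nn.functional.avg_pool2d` API; shapes and values below are illustrative only.

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 14, 14)

# Identity 1: avg_pool2d with count_include_pad=True equals an unpadded
# avg_pool2d applied to an explicitly zero-padded input. The pass builds the
# explicit padding with full(0.0) + cat nodes instead of F.pad.
ref = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1, count_include_pad=True)
dec = F.avg_pool2d(F.pad(x, (1, 1, 1, 1)), kernel_size=3, stride=2, padding=0)
assert torch.allclose(ref, dec)

# Identity 2: divisor_override is a plain average followed by a constant
# multiply with kernel_size / divisor_override (here 9 / 2).
ref = F.avg_pool2d(x, kernel_size=3, divisor_override=2)
dec = F.avg_pool2d(x, kernel_size=3) * (9 / 2)
assert torch.allclose(ref, dec)
```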
diff --git a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py
index 667903c7095..ff6db260099 100644
--- a/backends/arm/_passes/decompose_maxpool2d_with_dilation.py
+++ b/backends/arm/_passes/decompose_maxpool2d_with_dilation.py
@@ -36,6 +36,7 @@ def call_operator(self, op, args, kwargs, meta):
         stride = args[2]
         padding = args[3] if len(args) >= 4 else 0
         dilation = args[4] if len(args) >= 5 else 1
+        ceil_mode = args[5] if len(args) == 6 else False
 
         # Normalize attributes
         pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
@@ -45,12 +46,9 @@ def call_operator(self, op, args, kwargs, meta):
         )
         s_h, s_w = (stride, stride) if isinstance(stride, int) else stride
 
-        # If no dilation: call EXIR edge op with only supported args (x, kernel, stride[, padding])
+        # If no dilation: call EXIR edge op
         if d_h == 1 and d_w == 1:
-            minimal_args = [x, kernel_size, stride]
-            # only include padding if non-zero
-            if (pad_h, pad_w) != (0, 0):
-                minimal_args.append((pad_h, pad_w))
+            minimal_args = [x, kernel_size, stride, padding, dilation, ceil_mode]
             return super().call_operator(op, tuple(minimal_args), {}, meta)
 
         # Compute padded and packed dimensions for dilation > 1
@@ -102,7 +100,7 @@ def call_operator(self, op, args, kwargs, meta):
             if is_with_indices
             else exir_ops.edge.aten.max_pool2d.default
         )
-        pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0))
+        pool_args = (x2, (k_h, k_w), (s_h, s_w), (0, 0), 1, ceil_mode)
         pool_out = super().call_operator(
             pool_edge_op,
             pool_args,
diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py
index 9db58f663d3..677436ddc50 100644
--- a/backends/arm/operator_support/pool_2d_support.py
+++ b/backends/arm/operator_support/pool_2d_support.py
@@ -12,6 +12,9 @@
     register_tosa_support_check,
     SupportedTOSAOperatorCheck,
 )
+from executorch.backends.arm.operators.operator_validation_utils import (
+    adjust_pooling_pad_if_needed,
+)
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.exir.dialects._ops import ops as exir_ops
 
@@ -56,25 +59,42 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
         input_arg = get_first_fake_tensor(input_arg)
         shape = input_arg.data.shape  # type: ignore[union-attr]
 
+        # Calculate padding used in the final TOSA operator
         kernel = cast(tuple[int, int], node.args[1])
         stride = cast(tuple[int, int], node.args[2])
-        if len(node.args) > 3:
-            padding = cast(tuple[int, int], node.args[3])
-            # Padding case
-            if not all(1 <= k <= 8 for k in kernel) and not all(
-                v == 0 for v in padding
-            ):
-                self.reporter.report_reject(
-                    node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
-                )
-                return False
+        padding = cast(tuple[int, int], node.args[3]) if len(node.args) > 3 else (0, 0)
+        ceil_mode = cast(bool, node.args[4]) if len(node.args) > 4 else False
+        count_include_pad = cast(bool, node.args[5]) if len(node.args) > 5 else True
+        divisor_override = cast(int, node.args[6]) if len(node.args) > 6 else None
+
+        # If count_include_pad is True or divisor_override is given, padding is applied
+        # by concatenating zero elements rather than setting it in the avg_pool op.
+        if count_include_pad or divisor_override is not None:
+            tosa_padding = (0, 0, 0, 0)
+        # Otherwise, calculate the padding as done in the node visitor
         else:
-            if not kernel_check(kernel):
-                self.reporter.report_reject(
-                    node,
-                    f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y<=65536, got {kernel}",
-                )
-                return False
+            post_pad_h = adjust_pooling_pad_if_needed(
+                shape[2], kernel[0], stride[0], padding[0], ceil_mode
+            )
+            post_pad_w = adjust_pooling_pad_if_needed(
+                shape[3], kernel[1], stride[1], padding[1], ceil_mode
+            )
+            tosa_padding = (padding[0], post_pad_h, padding[1], post_pad_w)
+
+        if not all(1 <= k <= 8 for k in kernel) and not all(
+            v == 0 for v in tosa_padding
+        ):
+            self.reporter.report_reject(
+                node, f"Avgpool2d with padding needs kernel dims between 1 and 8, got {kernel}"
+            )
+            return False
+
+        if not kernel_check(kernel):
+            self.reporter.report_reject(
+                node,
+                f"Avgpool2d needs kernel_y < 256, kernel_x*kernel_y <= 65536, got {kernel}",
+            )
+            return False
 
         if not dim_check(shape):
             self.reporter.report_reject(
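Illustration (assumed values, not part of the diff): a concrete trace of the new support-check arithmetic for `AvgPool2d(3, 2, 1, count_include_pad=False)` on a 112x112 input, showing why the kernel-range restriction still applies when the padding stays in the pooling op.

```python
from math import floor

# avg_pool2d(kernel=3, stride=2, pad=1, ceil_mode=False) on a 112-wide dimension
input_size, kernel, stride, pad = 112, 3, 2, 1

output_size = floor((input_size - kernel + 2 * pad) / stride) + 1  # 56
post_pad = (output_size - 1) * stride - input_size + kernel - pad  # 0

# Per dimension the TOSA padding becomes (pre, post) = (1, 0), i.e.
# tosa_padding = (1, 0, 1, 0): not all zero, so the 1 <= k <= 8 kernel
# restriction is enforced (a kernel of 3 passes).
```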
diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py
index 2032a820ac0..f839ca380ec 100644
--- a/backends/arm/operators/op_avg_pool2d.py
+++ b/backends/arm/operators/op_avg_pool2d.py
@@ -54,6 +54,11 @@ def _build_generic_avgpool2d(
         kernel_size_list = inputs[1].special
         stride_size_list = inputs[2].special
 
+        if len(inputs) > 4:
+            ceil_mode = bool(inputs[4].number)
+        else:
+            ceil_mode = False
+
         try:
             pad_size_list = inputs[3].special
             pad_size_list = [
@@ -71,12 +76,14 @@ def _build_generic_avgpool2d(
             kernel_size_list[0],
             stride_size_list[0],
             pad_size_list[1],
+            ceil_mode,
         )
         pad_size_list[3] = adjust_pooling_pad_if_needed(
             input_tensor.shape[3],
             kernel_size_list[1],
             stride_size_list[1],
             pad_size_list[3],
+            ceil_mode,
         )
 
         attr = ts.TosaSerializerAttribute()
@@ -105,7 +112,7 @@ def define_node(
     ) -> None:
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
-        validate_num_inputs(self.target, inputs, [3, 4, 6])
+        validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
@@ -141,7 +148,7 @@ def define_node(
     ) -> None:
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
-        validate_num_inputs(self.target, inputs, [3, 4, 6])
+        validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target,
@@ -192,6 +199,11 @@ def _build_generic_avgpool2d(
         kernel_size_list = inputs[1].special
         stride_size_list = inputs[2].special
 
+        if len(inputs) > 4:
+            ceil_mode = bool(inputs[4].number)
+        else:
+            ceil_mode = False
+
         try:
             pad_size_list = inputs[3].special
             pad_size_list = [
@@ -209,12 +221,14 @@ def _build_generic_avgpool2d(
             kernel_size_list[0],
             stride_size_list[0],
             pad_size_list[1],
+            ceil_mode,
         )
         pad_size_list[3] = adjust_pooling_pad_if_needed(
             input_tensor.shape[3],
             kernel_size_list[1],
             stride_size_list[1],
             pad_size_list[3],
+            ceil_mode,
         )
 
         attr = ts.TosaSerializerAttribute()
@@ -247,7 +261,7 @@ def define_node(
     ) -> None:
         import serializer.tosa_serializer as ts  # type: ignore
 
-        validate_num_inputs(self.target, inputs, [3, 4, 6])
+        validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec
@@ -286,7 +300,7 @@ def define_node(
     ) -> None:
         import serializer.tosa_serializer as ts  # type: ignore
 
-        validate_num_inputs(self.target, inputs, [3, 4, 6])
+        validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target,
diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py
index 35180330d80..b3c779477ca 100644
--- a/backends/arm/operators/op_max_pool2d.py
+++ b/backends/arm/operators/op_max_pool2d.py
@@ -47,7 +47,7 @@ def define_node(
     ) -> None:
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
-        validate_num_inputs(self.target, inputs, [3, 4])
+        validate_num_inputs(self.target, inputs, [3, 4, 5, 6])
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target,
@@ -60,6 +60,10 @@ def define_node(
         kernel_size = inputs[1].special
         stride = inputs[2].special
 
+        if len(inputs) == 6:
+            ceil_mode = bool(inputs[5].number)
+        else:
+            ceil_mode = False
         try:
             pad_size_list = inputs[3].special
             pad_size_list = [
@@ -68,7 +72,7 @@ def define_node(
                 pad_size_list[1],
                 pad_size_list[1],
             ]
-        except IndexError:
+        except (IndexError, AttributeError):
             pad_size_list = [0, 0, 0, 0]
 
         # Adjust the padding as necessary
@@ -77,12 +81,14 @@ def define_node(
             kernel_size[0],
             stride[0],
             pad_size_list[1],
+            ceil_mode,
         )
         pad_size_list[3] = adjust_pooling_pad_if_needed(
             input_tensor.shape[3],
             kernel_size[1],
             stride[1],
             pad_size_list[3],
+            ceil_mode,
         )
 
         accumulator_type = output.dtype
@@ -138,7 +144,7 @@ def define_node(
 
         import serializer.tosa_serializer as ts  # type: ignore
 
-        validate_num_inputs(self.target, inputs, [3, 4])
+        validate_num_inputs(self.target, inputs, [3, 4, 5, 6])
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target,
@@ -151,6 +157,11 @@ def define_node(
         kernel_size = inputs[1].special
         stride = inputs[2].special
 
+        if len(inputs) == 6:
+            ceil_mode = bool(inputs[5].number)
+        else:
+            ceil_mode = False
+
         try:
             pad_size_list = inputs[3].special
             pad_size_list = [
@@ -159,7 +170,7 @@ def define_node(
                 pad_size_list[1],
                 pad_size_list[1],
             ]
-        except IndexError:
+        except (IndexError, AttributeError):
             pad_size_list = [0, 0, 0, 0]
 
         # Adjust the padding as necessary
@@ -168,12 +179,14 @@ def define_node(
             kernel_size[0],
             stride[0],
             pad_size_list[1],
+            ceil_mode,
        )
         pad_size_list[3] = adjust_pooling_pad_if_needed(
             input_tensor.shape[3],
             kernel_size[1],
             stride[1],
             pad_size_list[3],
+            ceil_mode,
         )
 
         attr = ts.TosaSerializerAttribute()
diff --git a/backends/arm/operators/operator_validation_utils.py b/backends/arm/operators/operator_validation_utils.py
index 2dea9e2874b..fde76f31c7a 100644
--- a/backends/arm/operators/operator_validation_utils.py
+++ b/backends/arm/operators/operator_validation_utils.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from math import ceil, floor
 from typing import Any, List, Optional
 
 from executorch.backends.arm.operators.node_visitor import NodeVisitor
@@ -183,11 +184,18 @@ def validate_valid_dtype(
 
 
 def adjust_pooling_pad_if_needed(
-    input_size: int, kernel_size: int, stride: int, pad: int
+    input_size: int, kernel_size: int, stride: int, pad: int, ceil_mode: bool
 ) -> int:
     """
-    Calculates the padding that needs to be removed to a pooling window to make it
-    divisible by the kernels stride. All inputs should correspond to the same dimension.
+    The ATen pooling ops have one value 'pad' per dimension to specify padding, but they
+    do not require input and output sizes to match up perfectly. Instead, the output
+    size is rounded up or down depending on ceil_mode, and padding at the end of the
+    input is automatically added or removed. TOSA, on the other hand, specifies two
+    padding values, one for pre-padding and one for post-padding, and these must satisfy
+
+        output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1
+
+    This function returns the post_pad value required to satisfy the above condition.
 
     Parameters:
     -----------
@@ -205,15 +213,16 @@ def adjust_pooling_pad_if_needed(
 
     Output:
     -------
-    An int, representing the padding to remove to make the window divisible.
+    An int, giving the post-padding to use for the dimension.
     """
-    if pad == 0:
-        return pad
 
-    mod_remainder = (input_size + 2 * pad - kernel_size) % stride
+    if ceil_mode:
+        output_size = ceil((input_size - kernel_size + 2 * pad) / stride) + 1
+    else:
+        output_size = floor((input_size - kernel_size + 2 * pad) / stride) + 1
 
-    # No need to adjust
-    if mod_remainder == 0:
-        return pad
+    # Solve for post_pad from
+    # output_size = (input_size + pre_pad + post_pad - kernel_size) / stride + 1
+    adjusted_post_pad = (output_size - 1) * stride - input_size + kernel_size - pad
 
-    return pad - mod_remainder
+    return adjusted_post_pad
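Worked example (assumed values, not part of the diff) of the rewritten helper: with `input_size=14, kernel_size=3, stride=2, pad=1`, ceil_mode flips the rounding, and the extra output element is realized as post-padding.

```python
from math import ceil, floor

def post_pad_for(input_size, kernel_size, stride, pad, ceil_mode):
    # Same arithmetic as adjust_pooling_pad_if_needed above.
    rounding = ceil if ceil_mode else floor
    output_size = rounding((input_size - kernel_size + 2 * pad) / stride) + 1
    return (output_size - 1) * stride - input_size + kernel_size - pad

assert post_pad_for(14, 3, 2, 1, ceil_mode=False) == 0  # output 7, (pre, post) = (1, 0)
assert post_pad_for(14, 3, 2, 1, ceil_mode=True) == 2   # output 8, (pre, post) = (1, 2)
```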
diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py
index 66d56ce584c..e2bbfc3a8cd 100644
--- a/backends/arm/test/ops/test_avg_pool2d.py
+++ b/backends/arm/test/ops/test_avg_pool2d.py
@@ -10,7 +10,7 @@
 
 import torch
 
-from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test import common
 
 from executorch.backends.arm.test.tester.test_pipeline import (
     EthosU55PipelineBI,
@@ -26,27 +26,19 @@
 input_t = Tuple[torch.Tensor]
 
 
-class AvgPool2d(torch.nn.Module):
-    def __init__(
-        self,
-        kernel_size: int | Tuple[int, int],
-        stride: int | Tuple[int, int],
-        padding: int | Tuple[int, int],
-    ):
-        super().__init__()
-        self.avg_pool_2d = torch.nn.AvgPool2d(
-            kernel_size=kernel_size, stride=stride, padding=padding
-        )
-
-    def forward(self, x):
-        return self.avg_pool_2d(x)
+class AvgPool2d(torch.nn.modules.AvgPool2d):
+    def forward(self, *args, **kwargs):
+        return super().forward(*args, **kwargs)
 
 
 test_modules = {
-    "zeros": lambda: (AvgPool2d(4, 2, 0), (torch.zeros(1, 16, 50, 32),)),
-    "ones": lambda: (AvgPool2d(4, 2, 0), (torch.ones(1, 16, 50, 32),)),
-    "rand": lambda: (AvgPool2d(4, 2, 0), (torch.rand(1, 16, 50, 32),)),
-    "randn": lambda: (AvgPool2d(4, 2, 0), (torch.randn(1, 16, 50, 32),)),
+    "zeros": lambda: (AvgPool2d(4, 2, 0, False), (torch.zeros(1, 16, 50, 32),)),
+    "ones": lambda: (AvgPool2d(4, 2, 0, False, True), (torch.ones(1, 16, 50, 32),)),
+    "rand": lambda: (AvgPool2d(4, 2, 0, False, True, 16), (torch.rand(1, 16, 50, 32),)),
+    "randn": lambda: (
+        AvgPool2d(4, 2, 0, divisor_override=16),
+        (torch.randn(1, 16, 50, 32),),
+    ),
     "kernel_3x3_stride_1_pad_1": lambda: (
         AvgPool2d((3, 3), (1, 1), 1),
         (torch.rand(1, 16, 50, 32),),
@@ -60,7 +52,7 @@
         (torch.rand(1, 16, 50, 32),),
     ),
     "non_divisible_window": lambda: (
-        AvgPool2d(3, 2, 1),
+        AvgPool2d(3, 2, 1, count_include_pad=False),
         (torch.rand(1, 16, 112, 112),),
     ),
     "non_divisible_window_height": lambda: (
@@ -68,9 +60,37 @@
         (torch.rand(1, 16, 56, 56),),
     ),
     "non_divisible_window_width": lambda: (
-        AvgPool2d(3, (1, 2), 1),
+        AvgPool2d(3, (1, 2), 1, count_include_pad=False),
         (torch.rand(1, 16, 56, 56),),
     ),
+    "non_divisible_window_ceil_mode": lambda: (
+        AvgPool2d(3, 2, 1, True),
+        (torch.rand(1, 16, 112, 112),),
+    ),
+    "non_divisible_window_height_ceil_mode": lambda: (
+        AvgPool2d(3, (2, 1), 1, True, False),
+        (torch.rand(1, 1, 14, 14),),
+    ),
+    "non_divisible_window_width_ceil_mode": lambda: (
+        AvgPool2d(3, (1, 2), 1, True, True),
+        (torch.rand(1, 1, 14, 14),),
+    ),
+    "divisor_override": lambda: (
+        AvgPool2d(3, 2, 1, False, False, divisor_override=2),
+        (torch.rand(1, 1, 14, 14),),
+    ),
+    "divisor_override_count_include_pad": lambda: (
+        AvgPool2d(3, 2, 1, False, True, divisor_override=2),
+        (torch.rand(1, 1, 14, 14),),
+    ),
+    "divisor_override_ceil_mode": lambda: (
+        AvgPool2d(3, 2, 1, True, False, divisor_override=2),
+        (torch.rand(1, 1, 14, 14),),
+    ),
+    "divisor_override_ceil_mode_count_include_pad": lambda: (
+        AvgPool2d(3, 2, 1, True, True, divisor_override=2),
+        (torch.rand(1, 1, 14, 14),),
+    ),
 }
 
 
@@ -83,11 +103,8 @@ def test_avg_pool2d_tosa_MI(test_module):
         input_tensor,
         aten_op,
         exir_op,
-        run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
     )
-    if conftest.is_option_enabled("tosa_ref_model"):
-        pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
-    pipeline.run()
+    pipeline.run()
 
 
 @common.parametrize("test_module", test_modules)
@@ -99,11 +116,8 @@ def test_avg_pool2d_tosa_BI(test_module):
         input_tensor,
         aten_op,
         exir_op,
-        run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
     )
-    if conftest.is_option_enabled("tosa_ref_model"):
-        pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
-    pipeline.run()
+    pipeline.run()
 
 
 @common.parametrize("test_module", test_modules)
@@ -118,7 +132,6 @@ def test_avg_pool2d_u55_BI(test_module):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
     pipeline.run()
 
 
@@ -134,27 +147,26 @@ def test_avg_pool2d_u85_BI(test_module):
         exir_op,
         run_on_fvp=True,
     )
-    pipeline.change_args("run_method_and_compare_outputs", qtol=1, atol=1, rtol=1)
     pipeline.run()
 
 
 reject_modules = {
     "kernel_1x1_stride_1_pad_0": lambda: (AvgPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)),
     "kernel_2x9_stride_1_pad_1": lambda: (
-        AvgPool2d((2, 9), 1, 1),
+        AvgPool2d((2, 9), 1, 1, count_include_pad=False),
         torch.rand(1, 16, 5, 32),
     ),
     "kernel_1x4_stride_0_pad_0": lambda: (
-        AvgPool2d(1, 4, 0),
+        AvgPool2d(1, 4, 0, count_include_pad=False),
         torch.rand(1, 10, 10, 10),
     ),
     "kernel_1x257_stride_1_pad_0_large": lambda: (
-        AvgPool2d((1, 257), 1, 0),
+        AvgPool2d((1, 257), 1, 0, count_include_pad=False),
        torch.rand(1, 16, 5, 300),
     ),
     "kernel_800x90_stride_1_pad_0_extreme": lambda: (
-        AvgPool2d((800, 90), 1, 0),
+        AvgPool2d((800, 90), 1, 0, count_include_pad=False),
         torch.rand(1, 16, 850, 100),
     ),
 }
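Reading aid (not part of the diff): the positional arguments in `test_modules` follow `torch.nn.AvgPool2d`'s standard signature, which is worth keeping in mind for cases like `AvgPool2d(4, 2, 0, False, True, 16)`.

```python
import torch

# torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False,
#                    count_include_pad=True, divisor_override=None)
pool = torch.nn.AvgPool2d(4, 2, 0, False, True, 16)
assert pool.ceil_mode is False and pool.divisor_override == 16
```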
diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py
index 1065573049e..55340a565e5 100644
--- a/backends/arm/test/ops/test_max_pool.py
+++ b/backends/arm/test/ops/test_max_pool.py
@@ -21,12 +21,24 @@
 
 test_data_suite = {
     # (test_name, test_data, [kernel_size, stride, padding])
-    "zeros": lambda: (torch.zeros(1, 1, 4, 8), [2, 2, 1]),
+    "zeros": lambda: (torch.zeros(1, 1, 4, 8), [(4, 6), 2, (2, 0)]),
     "ones": lambda: (torch.ones(1, 16, 50, 32), [4, 2, 0]),
     "rand": lambda: (torch.rand(1, 16, 52, 16), [4, 3, 0]),
     "non_divisible": lambda: (torch.rand(1, 16, 112, 112), [3, 2, 1]),
     "non_divisible_window_height": lambda: (torch.rand(1, 16, 56, 56), [3, (2, 1), 1]),
     "non_divisible_window_width": lambda: (torch.rand(1, 16, 56, 56), [3, (1, 2), 1]),
+    "non_divisible_ceil_mode": lambda: (
+        torch.rand(1, 16, 112, 112),
+        [3, 2, 1, 1, True],
+    ),
+    "non_divisible_window_height_ceil_mode": lambda: (
+        torch.rand(1, 16, 56, 56),
+        [3, (2, 1), 1, 1, True],
+    ),
+    "non_divisible_window_width_ceil_mode": lambda: (
+        torch.rand(1, 16, 56, 56),
+        [3, (1, 2), 1, 1, True],
+    ),
 }
 
 test_data_suite_mult_batches = {
@@ -61,6 +73,7 @@ def __init__(
         stride: int | Tuple[int, int],
         padding: int | Tuple[int, int],
         dilation: int | Tuple[int, int] = 1,
+        ceil_mode: bool = False,
     ):
         super().__init__()
         self.max_pool_2d = torch.nn.MaxPool2d(
@@ -68,6 +81,7 @@ def __init__(
             stride=stride,
             padding=padding,
             dilation=dilation,
+            ceil_mode=ceil_mode,
         )
 
     def forward(self, x):
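Quick check (illustrative, not part of the diff): the new `*_ceil_mode` cases exist because ceil_mode grows the pooled output by one element per dimension here, which the added TOSA post-padding must reproduce.

```python
import torch

x = torch.rand(1, 16, 112, 112)
floor_out = torch.nn.MaxPool2d(3, 2, 1, ceil_mode=False)(x)
ceil_out = torch.nn.MaxPool2d(3, 2, 1, ceil_mode=True)(x)
print(floor_out.shape)  # torch.Size([1, 16, 56, 56])
print(ceil_out.shape)   # torch.Size([1, 16, 57, 57])
```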