diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py index 7aa35a721b6..c1dd143a4fc 100644 --- a/backends/arm/operator_support/pool_2d_support.py +++ b/backends/arm/operator_support/pool_2d_support.py @@ -26,8 +26,8 @@ def stride_check(strides: tuple[int, int]) -> bool: def dim_check(shape=torch.Size) -> bool: - check = shape[0] == 1 - for dim in shape: + check = True + for dim in shape[1:]: check &= 1 <= dim <= 65536 return check @@ -59,7 +59,7 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): if not kernel_check(kernel): return False - return dim_check(shape) and stride_check(stride) + return dim_check(shape) and shape[0] == 1 and stride_check(stride) @register_tosa_support_check diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 6d6e0b8be5c..a31c12be3a0 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -232,8 +232,24 @@ def test_maxpool2d_tosa_u85_BI_mult_batches( if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(qtol=1, inputs=(test_data,)) + @parameterized.expand(test_data_suite_mult_batches) + @pytest.mark.corstone_fvp + @conftest.expectedFailureOnFVP # TODO: MLETORCH-433 + def test_maxpool2d_tosa_u55_BI_mult_batches( + self, + test_name: str, + test_data: torch.Tensor, + model_params: int | Tuple[int, int], + ): + tester = self._test_maxpool2d_tosa_ethos_BI_pipeline( + self.MaxPool2d(*model_params), + common.get_u55_compile_spec(), + (test_data,), + ) + if conftest.is_option_enabled("corstone_fvp"): + tester.run_method_and_compare_outputs(qtol=1, inputs=(test_data,)) + reject_data_suite = [ - (MaxPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)), (MaxPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)), (MaxPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)), (MaxPool2d((800, 90), 1, 0), torch.rand(1, 16, 850, 100)),