From de56339344cf9ef1594fc6db888b9bd21416187a Mon Sep 17 00:00:00 2001 From: Sebastian Larsson Date: Thu, 11 Sep 2025 16:40:24 +0200 Subject: [PATCH] Arm backend: Add docstrings for operator_support/pool_2d_support.py Change-Id: I8d565b59c880a5a4652b93c79180330823b0af9e Signed-off-by: Sebastian Larsson --- .../arm/operator_support/pool_2d_support.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py index ff453741f1f..c0428e45e03 100644 --- a/backends/arm/operator_support/pool_2d_support.py +++ b/backends/arm/operator_support/pool_2d_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide TOSA support checks for 2D pooling. + +Validate ``avg_pool2d`` and ``max_pool2d_with_indices`` against U55 profile +constraints including kernel size, stride, padding, and dimensionality. + +""" from typing import cast @@ -20,16 +26,48 @@ def kernel_check(kernel: tuple[int, int]) -> bool: + """Check if kernel size is within U55 constraints. + + Checks that ``kernel_x * kernel_y`` is in ``[1, 65536]`` and + ``kernel_y`` is in ``[1, 256]`` as required by the U55 profile. + + Args: + kernel (tuple[int, int]): Kernel height and width ``(kh, kw)``. + + Returns: + bool: True if the kernel passes validation. + + """ if not (1 <= kernel[0] * kernel[1] <= 65536): return False return 1 <= kernel[1] <= 256 def stride_check(strides: tuple[int, int]) -> bool: + """Check if strides are within U55 constraints. + + Args: + strides (tuple[int, int]): Vertical and horizontal strides. + + Returns: + bool: True if each stride is in ``[1, 3]``. + + """ return all(1 <= stride <= 3 for stride in strides) def dim_check(shape=torch.Size) -> bool: + """Check if non-batch dims are within U55 constraints. + + Verifies that all dimensions except batch are in ``[1, 65536]``. + + Args: + shape (torch.Size): Input tensor shape. + + Returns: + bool: True if all checked dimensions pass. + + """ check = True for dim in shape[1:]: check &= 1 <= dim <= 65536 @@ -38,6 +76,13 @@ def dim_check(shape=torch.Size) -> bool: @register_tosa_support_check class AvgPool2dSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support checks for ``aten.avg_pool2d``. + + Applies additional constraints when targeting the U55 subset, including + limits on kernel size, stride, padding behavior, and tensor ranks. + + """ + targets = [ exir_ops.edge.aten.avg_pool2d.default, ] @@ -48,6 +93,12 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck): ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + """Return True if ``avg_pool2d`` satisfies U55 constraints. + + Computes the effective TOSA padding (depending on ``count_include_pad`` + and ``divisor_override``) and validates kernel, stride, and shape limits. + + """ if not tosa_spec.is_U55_subset: return True @@ -115,6 +166,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): @register_tosa_support_check class MaxPool2dSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support checks for ``aten.max_pool2d_with_indices``. + + Applies additional constraints when targeting the U55 subset, including + limits on kernel size, stride, and tensor ranks. + + """ + targets = [ exir_ops.edge.aten.max_pool2d_with_indices.default, ] @@ -125,6 +183,9 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck): ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + """Return True if ``max_pool2d_with_indices`` satisfies U55 + constraints. + """ if not tosa_spec.is_U55_subset: return True