Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions backends/arm/operator_support/pool_2d_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Provide TOSA support checks for 2D pooling.

Validate ``avg_pool2d`` and ``max_pool2d_with_indices`` against U55 profile
constraints including kernel size, stride, padding, and dimensionality.

"""

from typing import cast

Expand All @@ -20,16 +26,48 @@


def kernel_check(kernel: tuple[int, int]) -> bool:
"""Check if kernel size is within U55 constraints.

Checks that ``kernel_x * kernel_y`` is in ``[1, 65536]`` and
``kernel_y`` is in ``[1, 256]`` as required by the U55 profile.

Args:
kernel (tuple[int, int]): Kernel height and width ``(kh, kw)``.

Returns:
bool: True if the kernel passes validation.

"""
if not (1 <= kernel[0] * kernel[1] <= 65536):
return False
return 1 <= kernel[1] <= 256


def stride_check(strides: tuple[int, int]) -> bool:
"""Check if strides are within U55 constraints.

Args:
strides (tuple[int, int]): Vertical and horizontal strides.

Returns:
bool: True if each stride is in ``[1, 3]``.

"""
return all(1 <= stride <= 3 for stride in strides)


def dim_check(shape=torch.Size) -> bool:
"""Check if non-batch dims are within U55 constraints.

Verifies that all dimensions except batch are in ``[1, 65536]``.

Args:
shape (torch.Size): Input tensor shape.

Returns:
bool: True if all checked dimensions pass.

"""
check = True
for dim in shape[1:]:
check &= 1 <= dim <= 65536
Expand All @@ -38,6 +76,13 @@ def dim_check(shape=torch.Size) -> bool:

@register_tosa_support_check
class AvgPool2dSupported(SupportedTOSAOperatorCheck):
"""Provide TOSA support checks for ``aten.avg_pool2d``.

Applies additional constraints when targeting the U55 subset, including
limits on kernel size, stride, padding behavior, and tensor ranks.

"""

targets = [
exir_ops.edge.aten.avg_pool2d.default,
]
Expand All @@ -48,6 +93,12 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
"""Return True if ``avg_pool2d`` satisfies U55 constraints.

Computes the effective TOSA padding (depending on ``count_include_pad``
and ``divisor_override``) and validates kernel, stride, and shape limits.

"""
if not tosa_spec.is_U55_subset:
return True

Expand Down Expand Up @@ -115,6 +166,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

@register_tosa_support_check
class MaxPool2dSupported(SupportedTOSAOperatorCheck):
"""Provide TOSA support checks for ``aten.max_pool2d_with_indices``.

Applies additional constraints when targeting the U55 subset, including
limits on kernel size, stride, and tensor ranks.

"""

targets = [
exir_ops.edge.aten.max_pool2d_with_indices.default,
]
Expand All @@ -125,6 +183,9 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
"""Return True if ``max_pool2d_with_indices`` satisfies U55
constraints.
"""
if not tosa_spec.is_U55_subset:
return True

Expand Down
Loading