Skip to content

Commit 4a644b7

Browse files
Arm backend: Add docstrings for operator_support/pool_2d_support.py (#14683)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent ffd27ca commit 4a644b7

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

backends/arm/operator_support/pool_2d_support.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Provide TOSA support checks for 2D pooling.
6+
7+
Validate ``avg_pool2d`` and ``max_pool2d_with_indices`` against U55 profile
8+
constraints including kernel size, stride, padding, and dimensionality.
9+
10+
"""
511

612
from typing import cast
713

@@ -20,16 +26,48 @@
2026

2127

2228
def kernel_check(kernel: tuple[int, int]) -> bool:
29+
"""Check if kernel size is within U55 constraints.
30+
31+
Checks that ``kernel_x * kernel_y`` is in ``[1, 65536]`` and
32+
``kernel_y`` is in ``[1, 256]`` as required by the U55 profile.
33+
34+
Args:
35+
kernel (tuple[int, int]): Kernel height and width ``(kh, kw)``.
36+
37+
Returns:
38+
bool: True if the kernel passes validation.
39+
40+
"""
2341
if not (1 <= kernel[0] * kernel[1] <= 65536):
2442
return False
2543
return 1 <= kernel[1] <= 256
2644

2745

2846
def stride_check(strides: tuple[int, int]) -> bool:
47+
"""Check if strides are within U55 constraints.
48+
49+
Args:
50+
strides (tuple[int, int]): Vertical and horizontal strides.
51+
52+
Returns:
53+
bool: True if each stride is in ``[1, 3]``.
54+
55+
"""
2956
return all(1 <= stride <= 3 for stride in strides)
3057

3158

3259
def dim_check(shape=torch.Size) -> bool:
60+
"""Check if non-batch dims are within U55 constraints.
61+
62+
Verifies that all dimensions except batch are in ``[1, 65536]``.
63+
64+
Args:
65+
shape (torch.Size): Input tensor shape.
66+
67+
Returns:
68+
bool: True if all checked dimensions pass.
69+
70+
"""
3371
check = True
3472
for dim in shape[1:]:
3573
check &= 1 <= dim <= 65536
@@ -38,6 +76,13 @@ def dim_check(shape=torch.Size) -> bool:
3876

3977
@register_tosa_support_check
4078
class AvgPool2dSupported(SupportedTOSAOperatorCheck):
79+
"""Provide TOSA support checks for ``aten.avg_pool2d``.
80+
81+
Applies additional constraints when targeting the U55 subset, including
82+
limits on kernel size, stride, padding behavior, and tensor ranks.
83+
84+
"""
85+
4186
targets = [
4287
exir_ops.edge.aten.avg_pool2d.default,
4388
]
@@ -48,6 +93,12 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
4893
]
4994

5095
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
96+
"""Return True if ``avg_pool2d`` satisfies U55 constraints.
97+
98+
Computes the effective TOSA padding (depending on ``count_include_pad``
99+
and ``divisor_override``) and validates kernel, stride, and shape limits.
100+
101+
"""
51102
if not tosa_spec.is_U55_subset:
52103
return True
53104

@@ -115,6 +166,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
115166

116167
@register_tosa_support_check
117168
class MaxPool2dSupported(SupportedTOSAOperatorCheck):
169+
"""Provide TOSA support checks for ``aten.max_pool2d_with_indices``.
170+
171+
Applies additional constraints when targeting the U55 subset, including
172+
limits on kernel size, stride, and tensor ranks.
173+
174+
"""
175+
118176
targets = [
119177
exir_ops.edge.aten.max_pool2d_with_indices.default,
120178
]
@@ -125,6 +183,9 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
125183
]
126184

127185
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
186+
"""Return True if ``max_pool2d_with_indices`` satisfies U55
187+
constraints.
188+
"""
128189
if not tosa_spec.is_U55_subset:
129190
return True
130191

0 commit comments

Comments
 (0)