22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Provide TOSA support checks for 2D pooling.
6+
7+ Validate ``avg_pool2d`` and ``max_pool2d_with_indices`` against U55 profile
8+ constraints including kernel size, stride, padding, and dimensionality.
9+
10+ """
511
612from typing import cast
713
2026
2127
2228def kernel_check (kernel : tuple [int , int ]) -> bool :
29+ """Check if kernel size is within U55 constraints.
30+
31+ Checks that ``kernel_x * kernel_y`` is in ``[1, 65536]`` and
32+ ``kernel_y`` is in ``[1, 256]`` as required by the U55 profile.
33+
34+ Args:
35+ kernel (tuple[int, int]): Kernel height and width ``(kh, kw)``.
36+
37+ Returns:
38+ bool: True if the kernel passes validation.
39+
40+ """
2341 if not (1 <= kernel [0 ] * kernel [1 ] <= 65536 ):
2442 return False
2543 return 1 <= kernel [1 ] <= 256
2644
2745
2846def stride_check (strides : tuple [int , int ]) -> bool :
47+ """Check if strides are within U55 constraints.
48+
49+ Args:
50+ strides (tuple[int, int]): Vertical and horizontal strides.
51+
52+ Returns:
53+ bool: True if each stride is in ``[1, 3]``.
54+
55+ """
2956 return all (1 <= stride <= 3 for stride in strides )
3057
3158
3259def dim_check (shape = torch .Size ) -> bool :
60+ """Check if non-batch dims are within U55 constraints.
61+
62+ Verifies that all dimensions except batch are in ``[1, 65536]``.
63+
64+ Args:
65+ shape (torch.Size): Input tensor shape.
66+
67+ Returns:
68+ bool: True if all checked dimensions pass.
69+
70+ """
3371 check = True
3472 for dim in shape [1 :]:
3573 check &= 1 <= dim <= 65536
@@ -38,6 +76,13 @@ def dim_check(shape=torch.Size) -> bool:
3876
3977@register_tosa_support_check
4078class AvgPool2dSupported (SupportedTOSAOperatorCheck ):
79+ """Provide TOSA support checks for ``aten.avg_pool2d``.
80+
81+ Applies additional constraints when targeting the U55 subset, including
82+ limits on kernel size, stride, padding behavior, and tensor ranks.
83+
84+ """
85+
4186 targets = [
4287 exir_ops .edge .aten .avg_pool2d .default ,
4388 ]
@@ -48,6 +93,12 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
4893 ]
4994
5095 def is_node_tosa_supported (self , node : fx .Node , tosa_spec : TosaSpecification ):
96+ """Return True if ``avg_pool2d`` satisfies U55 constraints.
97+
98+ Computes the effective TOSA padding (depending on ``count_include_pad``
99+ and ``divisor_override``) and validates kernel, stride, and shape limits.
100+
101+ """
51102 if not tosa_spec .is_U55_subset :
52103 return True
53104
@@ -115,6 +166,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
115166
116167@register_tosa_support_check
117168class MaxPool2dSupported (SupportedTOSAOperatorCheck ):
169+ """Provide TOSA support checks for ``aten.max_pool2d_with_indices``.
170+
171+ Applies additional constraints when targeting the U55 subset, including
172+ limits on kernel size, stride, and tensor ranks.
173+
174+ """
175+
118176 targets = [
119177 exir_ops .edge .aten .max_pool2d_with_indices .default ,
120178 ]
@@ -125,6 +183,9 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
125183 ]
126184
127185 def is_node_tosa_supported (self , node : fx .Node , tosa_spec : TosaSpecification ):
186+ """Return True if ``max_pool2d_with_indices`` satisfies U55
187+ constraints.
188+ """
128189 if not tosa_spec .is_U55_subset :
129190 return True
130191
0 commit comments