2
2
#
3
3
# This source code is licensed under the BSD-style license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
+ """Provide TOSA support checks for 2D pooling.
6
+
7
+ Validate ``avg_pool2d`` and ``max_pool2d_with_indices`` against U55 profile
8
+ constraints including kernel size, stride, padding, and dimensionality.
9
+
10
+ """
5
11
6
12
from typing import cast
7
13
20
26
21
27
22
28
def kernel_check (kernel : tuple [int , int ]) -> bool :
29
+ """Check if kernel size is within U55 constraints.
30
+
31
+ Checks that ``kernel_x * kernel_y`` is in ``[1, 65536]`` and
32
+ ``kernel_y`` is in ``[1, 256]`` as required by the U55 profile.
33
+
34
+ Args:
35
+ kernel (tuple[int, int]): Kernel height and width ``(kh, kw)``.
36
+
37
+ Returns:
38
+ bool: True if the kernel passes validation.
39
+
40
+ """
23
41
if not (1 <= kernel [0 ] * kernel [1 ] <= 65536 ):
24
42
return False
25
43
return 1 <= kernel [1 ] <= 256
26
44
27
45
28
46
def stride_check (strides : tuple [int , int ]) -> bool :
47
+ """Check if strides are within U55 constraints.
48
+
49
+ Args:
50
+ strides (tuple[int, int]): Vertical and horizontal strides.
51
+
52
+ Returns:
53
+ bool: True if each stride is in ``[1, 3]``.
54
+
55
+ """
29
56
return all (1 <= stride <= 3 for stride in strides )
30
57
31
58
32
59
def dim_check (shape = torch .Size ) -> bool :
60
+ """Check if non-batch dims are within U55 constraints.
61
+
62
+ Verifies that all dimensions except batch are in ``[1, 65536]``.
63
+
64
+ Args:
65
+ shape (torch.Size): Input tensor shape.
66
+
67
+ Returns:
68
+ bool: True if all checked dimensions pass.
69
+
70
+ """
33
71
check = True
34
72
for dim in shape [1 :]:
35
73
check &= 1 <= dim <= 65536
@@ -38,6 +76,13 @@ def dim_check(shape=torch.Size) -> bool:
38
76
39
77
@register_tosa_support_check
40
78
class AvgPool2dSupported (SupportedTOSAOperatorCheck ):
79
+ """Provide TOSA support checks for ``aten.avg_pool2d``.
80
+
81
+ Applies additional constraints when targeting the U55 subset, including
82
+ limits on kernel size, stride, padding behavior, and tensor ranks.
83
+
84
+ """
85
+
41
86
targets = [
42
87
exir_ops .edge .aten .avg_pool2d .default ,
43
88
]
@@ -48,6 +93,12 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
48
93
]
49
94
50
95
def is_node_tosa_supported (self , node : fx .Node , tosa_spec : TosaSpecification ):
96
+ """Return True if ``avg_pool2d`` satisfies U55 constraints.
97
+
98
+ Computes the effective TOSA padding (depending on ``count_include_pad``
99
+ and ``divisor_override``) and validates kernel, stride, and shape limits.
100
+
101
+ """
51
102
if not tosa_spec .is_U55_subset :
52
103
return True
53
104
@@ -115,6 +166,13 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
115
166
116
167
@register_tosa_support_check
117
168
class MaxPool2dSupported (SupportedTOSAOperatorCheck ):
169
+ """Provide TOSA support checks for ``aten.max_pool2d_with_indices``.
170
+
171
+ Applies additional constraints when targeting the U55 subset, including
172
+ limits on kernel size, stride, and tensor ranks.
173
+
174
+ """
175
+
118
176
targets = [
119
177
exir_ops .edge .aten .max_pool2d_with_indices .default ,
120
178
]
@@ -125,6 +183,9 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
125
183
]
126
184
127
185
def is_node_tosa_supported (self , node : fx .Node , tosa_spec : TosaSpecification ):
186
+ """Return True if ``max_pool2d_with_indices`` satisfies U55
187
+ constraints.
188
+ """
128
189
if not tosa_spec .is_U55_subset :
129
190
return True
130
191
0 commit comments