Skip to content

Commit 0cd8256

Browse files
Arm backend: Add docstrings for operator_support/convolution_support.py (pytorch#14684)
Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 0081bef commit 0cd8256

File tree

1 file changed

+38
-9
lines changed

1 file changed

+38
-9
lines changed

backends/arm/operator_support/convolution_support.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Declare operator support for ``aten.convolution`` in TOSA.
6+
7+
Provide general checks and hardware-specific constraints (e.g., U55 subset) for
8+
convolution nodes prior to delegation to the TOSA backend.
9+
10+
"""
511

612
from typing import cast
713

@@ -18,15 +24,24 @@
1824

1925
@register_tosa_support_check
2026
class ConvolutionSupported(SupportedTOSAOperatorCheck):
27+
"""Provide TOSA support check for convolutions."""
28+
2129
targets = [exir_ops.edge.aten.convolution.default]
2230

2331
tosa_specs = [
2432
TosaSpecification.create_from_string("TOSA-1.0+INT"),
2533
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2634
]
2735

28-
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
36+
def is_node_tosa_supported(
37+
self, node: fx.Node, tosa_spec: TosaSpecification
38+
) -> bool:
39+
"""Return True if the node is supported by TOSA.
2940
41+
Reject transposed convolutions and convolutions with non-zero output
42+
padding. Apply additional hardware-specific constraints for U55.
43+
44+
"""
3045
# Not implemented
3146
transposed = cast(bool, node.args[6])
3247
output_padding = cast(list[int], node.args[7])
@@ -46,9 +61,19 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4661
else:
4762
return True
4863

49-
def _is_node_supported_u55(self, node: fx.Node):
50-
"""Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)"""
64+
def _is_node_supported_u55(self, node: fx.Node) -> bool:
65+
"""Enforce Ethos-U55-specific constraints (Vela 4.2.0).
66+
67+
Check channel dimensions, kernel sizes, and stride/pad/dilation
68+
combinations permitted on U55.
5169
70+
Args:
71+
node (fx.Node): Convolution node to validate.
72+
73+
Returns:
74+
bool: True if supported; otherwise, False.
75+
76+
"""
5277
shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape
5378
shape_out = node.meta["val"].shape
5479
kernel = cast(fx.Node, node.args[1]).meta["val"].shape
@@ -98,13 +123,17 @@ def _is_node_supported_u55(self, node: fx.Node):
98123
return True
99124

100125
def _stride_condition(self, node: fx.Node) -> bool:
101-
"""This condition is somewhat complex but boils down
102-
to not supporting stride > 3, unless we have some special conditions.
103-
This condition is a simplified, relaxed version of the hardware constraint,
104-
since the actual constraint requires information not available
105-
here (without a lot of work).
126+
"""Check a simplified stride/padding/dilation constraint.
127+
128+
Disallow strides greater than 3 unless there is no padding and the
129+
dilation is 1. For 3D convolutions, enforce ``stride_z <= 1``.
130+
131+
Args:
132+
node (fx.Node): Convolution node to evaluate.
133+
134+
Returns:
135+
bool: True if the condition is satisfied.
106136
107-
This means that we might accept ops that are not actually supported.
108137
"""
109138
strides = cast(list[int], node.args[3])
110139
has_padding = any(pad > 0 for pad in cast(list[int], node.args[4]))

0 commit comments

Comments
 (0)