diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py index 6e9d3b3528e..f335c5046f5 100644 --- a/backends/arm/operator_support/convolution_support.py +++ b/backends/arm/operator_support/convolution_support.py @@ -2,6 +2,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Declare operator support for ``aten.convolution`` in TOSA. + +Provide general checks and hardware-specific constraints (e.g., U55 subset) for +convolution nodes prior to delegation to the TOSA backend. + +""" from typing import cast @@ -18,6 +24,8 @@ @register_tosa_support_check class ConvolutionSupported(SupportedTOSAOperatorCheck): + """Provide TOSA support check for convolutions.""" + targets = [exir_ops.edge.aten.convolution.default] tosa_specs = [ @@ -25,8 +33,15 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck): TosaSpecification.create_from_string("TOSA-1.0+FP"), ] - def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): + def is_node_tosa_supported( + self, node: fx.Node, tosa_spec: TosaSpecification + ) -> bool: + """Return True if the node is supported by TOSA. + Reject transposed convolutions and convolutions with non-zero output + padding. Apply additional hardware-specific constraints for U55. + + """ # Not implemented transposed = cast(bool, node.args[6]) output_padding = cast(list[int], node.args[7]) @@ -46,9 +61,19 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): else: return True - def _is_node_supported_u55(self, node: fx.Node): - """Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)""" + def _is_node_supported_u55(self, node: fx.Node) -> bool: + """Enforce Ethos-U55-specific constraints (Vela 4.2.0). + + Check channel dimensions, kernel sizes, and stride/pad/dilation + combinations permitted on U55. + Args: + node (fx.Node): Convolution node to validate. + + Returns: + bool: True if supported; otherwise, False. + + """ shape_in = cast(torch.Tensor, node.all_input_nodes[0].meta["val"]).shape shape_out = node.meta["val"].shape kernel = cast(fx.Node, node.args[1]).meta["val"].shape @@ -98,13 +123,17 @@ def _is_node_supported_u55(self, node: fx.Node): return True def _stride_condition(self, node: fx.Node) -> bool: - """This condition is somewhat complex but boils down - to not supporting stride > 3, unless we have some special conditions. - This condition is a simplified, relaxed version of the hardware constraint, - since the actual constraint requires information not available - here (without a lot of work). + """Check a simplified stride/padding/dilation constraint. + + Disallow strides greater than 3 unless there is no padding and the + dilation is 1. For 3D convolutions, enforce ``stride_z <= 1``. + + Args: + node (fx.Node): Convolution node to evaluate. + + Returns: + bool: True if the condition is satisfied. - This means that we might accept ops that are not actually supported. """ strides = cast(list[int], node.args[3]) has_padding = any(pad > 0 for pad in cast(list[int], node.args[4]))