22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
5+ """Declare operator support for ``aten.convolution`` in TOSA.
6+
7+ Provide general checks and hardware-specific constraints (e.g., U55 subset) for
8+ convolution nodes prior to delegation to the TOSA backend.
9+
10+ """
511
612from typing import cast
713
1824
1925@register_tosa_support_check
2026class ConvolutionSupported (SupportedTOSAOperatorCheck ):
27+ """Provide TOSA support check for convolutions."""
28+
2129 targets = [exir_ops .edge .aten .convolution .default ]
2230
2331 tosa_specs = [
2432 TosaSpecification .create_from_string ("TOSA-1.0+INT" ),
2533 TosaSpecification .create_from_string ("TOSA-1.0+FP" ),
2634 ]
2735
28- def is_node_tosa_supported (self , node : fx .Node , tosa_spec : TosaSpecification ):
36+ def is_node_tosa_supported (
37+ self , node : fx .Node , tosa_spec : TosaSpecification
38+ ) -> bool :
39+ """Return True if the node is supported by TOSA.
2940
41+ Reject transposed convolutions and convolutions with non-zero output
42+ padding. Apply additional hardware-specific constraints for U55.
43+
44+ """
3045 # Not implemented
3146 transposed = cast (bool , node .args [6 ])
3247 output_padding = cast (list [int ], node .args [7 ])
@@ -46,9 +61,19 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4661 else :
4762 return True
4863
49- def _is_node_supported_u55 (self , node : fx .Node ):
50- """Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)"""
64+ def _is_node_supported_u55 (self , node : fx .Node ) -> bool :
65+ """Enforce Ethos-U55-specific constraints (Vela 4.2.0).
66+
67+ Check channel dimensions, kernel sizes, and stride/pad/dilation
68+ combinations permitted on U55.
5169
70+ Args:
71+ node (fx.Node): Convolution node to validate.
72+
73+ Returns:
74+ bool: True if supported; otherwise, False.
75+
76+ """
5277 shape_in = cast (torch .Tensor , node .all_input_nodes [0 ].meta ["val" ]).shape
5378 shape_out = node .meta ["val" ].shape
5479 kernel = cast (fx .Node , node .args [1 ]).meta ["val" ].shape
@@ -98,13 +123,17 @@ def _is_node_supported_u55(self, node: fx.Node):
98123 return True
99124
100125 def _stride_condition (self , node : fx .Node ) -> bool :
101- """This condition is somewhat complex but boils down
102- to not supporting stride > 3, unless we have some special conditions.
103- This condition is a simplified, relaxed version of the hardware constraint,
104- since the actual constraint requires information not available
105- here (without a lot of work).
126+ """Check a simplified stride/padding/dilation constraint.
127+
128+ Disallow strides greater than 3 unless there is no padding and the
129+ dilation is 1. For 3D convolutions, enforce ``stride_z <= 1``.
130+
131+ Args:
132+ node (fx.Node): Convolution node to evaluate.
133+
134+ Returns:
135+ bool: True if the condition is satisfied.
106136
107- This means that we might accept ops that are not actually supported.
108137 """
109138 strides = cast (list [int ], node .args [3 ])
110139 has_padding = any (pad > 0 for pad in cast (list [int ], node .args [4 ]))
0 commit comments