diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index 983aa091eec..27ddb95637b 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -2,6 +2,13 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +"""Provide Ethos-U55 specific operator support checks. + +Contains dtype validation, explicit unsupported-op filtering, and shape/ +permutation constraints for view and permute operations when targeting the +Ethos-U55 subset of TOSA. + +""" # pyre-unsafe @@ -21,6 +28,19 @@ def _try_determine_dtype(node: fx.Node) -> torch.dtype | None: + """Return an inferred dtype for a node when possible. + + Uses fake tensor metadata and nearby quantize/dequantize nodes to infer the + integer dtype used by the operator. Returns ``None`` when the dtype cannot + be determined reliably. + + Args: + node (fx.Node): FX node to inspect. + + Returns: + torch.dtype | None: Inferred dtype or ``None`` if unknown. + + """ dtype = get_first_fake_tensor(node).dtype if not dtype.is_floating_point: return dtype @@ -34,8 +54,23 @@ def _try_determine_dtype(node: fx.Node) -> torch.dtype | None: class EthosU55DtypeSupport(OperatorSupportBase): + """Validate dtypes for U55-supported operators. + + Ensures operators use a supported integer dtype according to U55 + constraints, with specific rules for convolution, matmul, and table ops. + + Attributes: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ super().__init__() self.reporter = reporter @@ -52,7 +87,20 @@ def __init__(self, reporter: WhyNoPartitionReporter): def is_node_supported( # noqa: C901 self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return True if the node uses supported dtypes. + Applies per-operator dtype rules for U55, including specialized input + and weight constraints for convolution and int8-only checks for table + operations and matmul variants. + + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node to check. + + Returns: + bool: True if supported; otherwise, False. + + """ dtype = _try_determine_dtype(node) if dtype is None: # If we couldn't determine dtype, just return ok. @@ -112,10 +160,12 @@ def is_node_supported( # noqa: C901 class EthosU55NotSupported(OperatorSupportBase): - """ - Certain operators are not supported on U55. These are listed in `unsupported_ops`. - The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious. - For unimplemented operators, this is the anticipated mapping, and it might be incorrect. + """Reject operators not supported by Ethos-U55. + + The ``unsupported_ops`` list contains aten ops that either map to TOSA + operators the U55 cannot run or remain unimplemented. The mapping comments + capture expected TOSA equivalents when not obvious. + """ unsupported_ops = [ @@ -165,12 +215,27 @@ class EthosU55NotSupported(OperatorSupportBase): ] def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ self.reporter = reporter def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return False for nodes explicitly unsupported on U55. + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node to check. + + Returns: + bool: False if ``node.target`` is in ``unsupported_ops``; else True. + + """ if node.target in self.unsupported_ops: self.reporter.report_reject(node, "Op is not supported on U55.") return False @@ -182,12 +247,37 @@ def is_node_supported( class EthosU55ViewCheck(OperatorSupportBase): + """Validate view/select shapes and dtypes for U55. + + Performs lightweight checks on output shape rank and product constraints, + with awareness that transposes may be inserted around view/select during + lowering to channels-last. + + Attributes: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ super().__init__() self.reporter = reporter def axes_product(self, nhwc_shape: shape_t) -> int: + """Return the product of all axes in ``nhwc_shape``. + + Args: + nhwc_shape (list[int]): Shape in NHWC order. + + Returns: + int: Product of the axis sizes. + + """ product = 1 for axes in nhwc_shape: product *= axes @@ -197,26 +287,27 @@ def axes_product(self, nhwc_shape: shape_t) -> int: def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - """ - Check whether a given view node is supported on U55. + """Check whether a given view/select node is U55-supported. Currently only checks dtypes and product of axes. - It is not the view operator itself that is not supported on U55. In order for the - view operator to be compatible with the channels-last format of TosaBackend, - transposes may need to be inserted before and after the view op. If that happens - and that transpose operator does not adhere to the limitations then it will - result in the following error: + It is not the view operator itself that is not supported on U55. In + order for the view operator to be compatible with the channels-last + format of TosaBackend, transposes may need to be inserted before and + after the view op. If that happens and that transpose operator does not + adhere to the limitations then it will result in the following error: CPU performance estimation for "Transpose" not implemented. ... CPU operations are not supported for GraphAPI input Args: - node: The FX node representing the view_copy operator. + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node for ``view_copy`` or ``select``. Returns: - False if the operator is not support and True if it is supported. + bool: False if rejected by constraints; otherwise, True. + """ # Select decomposes into squeeze, which in turn becomes a view. Therefore, # perform the same check on select operators as view operators. @@ -279,14 +370,40 @@ def is_node_supported( class EthosU55TransposeCheck(OperatorSupportBase): + """Validate permute nodes against U55 reshape/transpose limits. + + Applies dtype- and rank-specific constraints to permutations. Tests both + NCHW and NHWC interpretations for rank-3/4 shapes since dim order is unknown + at partition time. + + Attributes: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ def __init__(self, reporter: WhyNoPartitionReporter): + """Initialize the check with a reporter. + + Args: + reporter (WhyNoPartitionReporter): Reporter for rejection reasons. + + """ super().__init__() self.reporter = reporter def _pad_to_rank_4( self, shape: shape_t, permutation: list[int] ) -> tuple[shape_t, shape_t]: + """Pad shape/permutation to rank 4 by prepending ones/indices. + + Args: + shape (list[int]): Original shape. + permutation (list[int]): Original permutation indices. + + Returns: + tuple[list[int], list[int]]: Padded shape and permutation. + + """ diff = 4 - len(shape) padded_shape = [1] * diff + shape for i in range(len(permutation)): @@ -295,6 +412,15 @@ def _pad_to_rank_4( return padded_shape, padded_permutation def axes_product(self, nhwc_shape: shape_t) -> int: + """Return the product of all axes in ``nhwc_shape``. + + Args: + nhwc_shape (list[int]): Shape in NHWC order. + + Returns: + int: Product of the axis sizes. + + """ product = 1 for axes in nhwc_shape: product *= axes @@ -303,7 +429,7 @@ def axes_product(self, nhwc_shape: shape_t) -> int: def _permute_constraint_i8_i16( self, nhwc_shape: list[int], permutation: list[int] ) -> bool: - """Returns True if the constraints are ok.""" + """Return True if permutation meets i8/i16 constraints.""" N, H, W, C = nhwc_shape match permutation: case (0, 1, 2, 3): # NHWC -> NHWC @@ -316,7 +442,7 @@ def _permute_constraint_i8_i16( def _permute_constraint_i32( self, nhwc_shape: list[int], permutation: list[int] ) -> bool: - """Returns True if the constraints are ok.""" + """Return True if permutation meets i32 constraints.""" N, H, W, C = nhwc_shape match permutation: case (0, 1, 2, 3): # NHWC -> NHWC @@ -329,6 +455,7 @@ def _permute_constraint_i32( return False def _permute_constraint(self, shape, permutation, dtype): + """Return True if permutation meets dtype-specific constraints.""" if dtype in (torch.int8, torch.int16): return self._permute_constraint_i8_i16(shape, permutation) if dtype == torch.int32: @@ -338,7 +465,19 @@ def _permute_constraint(self, shape, permutation, dtype): def is_node_supported( self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: + """Return True if a permute node satisfies U55 constraints. + + Tests both NCHW and NHWC interpretations for rank-3/4 shapes, and + applies dtype-specific limits to shapes and permutations. + + Args: + submodules (typing.Mapping[str, torch.nn.Module]): Exported modules. + node (fx.Node): FX node to check. + + Returns: + bool: True if supported; otherwise, False. + """ if not node.target == exir_ops.edge.aten.permute_copy.default: return True