Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 154 additions & 15 deletions backends/arm/operator_support/ethos_u55_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Provide Ethos-U55 specific operator support checks.

Contains dtype validation, explicit unsupported-op filtering, and shape/
permutation constraints for view and permute operations when targeting the
Ethos-U55 subset of TOSA.

"""

# pyre-unsafe

Expand All @@ -21,6 +28,19 @@


def _try_determine_dtype(node: fx.Node) -> torch.dtype | None:
"""Return an inferred dtype for a node when possible.

Uses fake tensor metadata and nearby quantize/dequantize nodes to infer the
integer dtype used by the operator. Returns ``None`` when the dtype cannot
be determined reliably.

Args:
node (fx.Node): FX node to inspect.

Returns:
torch.dtype | None: Inferred dtype or ``None`` if unknown.

"""
dtype = get_first_fake_tensor(node).dtype
if not dtype.is_floating_point:
return dtype
Expand All @@ -34,8 +54,23 @@ def _try_determine_dtype(node: fx.Node) -> torch.dtype | None:


class EthosU55DtypeSupport(OperatorSupportBase):
"""Validate dtypes for U55-supported operators.

Ensures operators use a supported integer dtype according to U55
constraints, with specific rules for convolution, matmul, and table ops.

Attributes:
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.

"""

def __init__(self, reporter: WhyNoPartitionReporter):
"""Initialize the check with a reporter.

Args:
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.

"""
super().__init__()
self.reporter = reporter

Expand All @@ -52,7 +87,20 @@ def __init__(self, reporter: WhyNoPartitionReporter):
def is_node_supported( # noqa: C901
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
"""Return True if the node uses supported dtypes.

Applies per-operator dtype rules for U55, including specialized input
and weight constraints for convolution and int8-only checks for table
operations and matmul variants.

Args:
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
node (fx.Node): FX node to check.

Returns:
bool: True if supported; otherwise, False.

"""
dtype = _try_determine_dtype(node)
if dtype is None:
# If we couldn't determine dtype, just return ok.
Expand Down Expand Up @@ -112,10 +160,12 @@ def is_node_supported( # noqa: C901


class EthosU55NotSupported(OperatorSupportBase):
"""
Certain operators are not supported on U55. These are listed in `unsupported_ops`.
The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious.
For unimplemented operators, this is the anticipated mapping, and it might be incorrect.
"""Reject operators not supported by Ethos-U55.

The ``unsupported_ops`` list contains aten ops that either map to TOSA
operators the U55 cannot run or remain unimplemented. The mapping comments
capture expected TOSA equivalents when not obvious.

"""

unsupported_ops = [
Expand Down Expand Up @@ -165,12 +215,27 @@ class EthosU55NotSupported(OperatorSupportBase):
]

def __init__(self, reporter: WhyNoPartitionReporter):
"""Initialize the check with a reporter.

Args:
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.

"""
self.reporter = reporter

def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
"""Return False for nodes explicitly unsupported on U55.

Args:
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
node (fx.Node): FX node to check.

Returns:
bool: False if ``node.target`` is in ``unsupported_ops``; else True.

"""
if node.target in self.unsupported_ops:
self.reporter.report_reject(node, "Op is not supported on U55.")
return False
Expand All @@ -182,12 +247,37 @@ def is_node_supported(


class EthosU55ViewCheck(OperatorSupportBase):
"""Validate view/select shapes and dtypes for U55.

Performs lightweight checks on output shape rank and product constraints,
with awareness that transposes may be inserted around view/select during
lowering to channels-last.

Attributes:
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.

"""

def __init__(self, reporter: WhyNoPartitionReporter):
"""Initialize the check with a reporter.

Args:
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.

"""
super().__init__()
self.reporter = reporter

def axes_product(self, nhwc_shape: shape_t) -> int:
"""Return the product of all axes in ``nhwc_shape``.

Args:
nhwc_shape (list[int]): Shape in NHWC order.

Returns:
int: Product of the axis sizes.

"""
product = 1
for axes in nhwc_shape:
product *= axes
Expand All @@ -197,26 +287,27 @@ def axes_product(self, nhwc_shape: shape_t) -> int:
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
"""
Check whether a given view node is supported on U55.
"""Check whether a given view/select node is U55-supported.

Currently only checks dtypes and product of axes.

It is not the view operator itself that is not supported on U55. In order for the
view operator to be compatible with the channels-last format of TosaBackend,
transposes may need to be inserted before and after the view op. If that happens
and that transpose operator does not adhere to the limitations then it will
result in the following error:
It is not the view operator itself that is not supported on U55. In
order for the view operator to be compatible with the channels-last
format of TosaBackend, transposes may need to be inserted before and
after the view op. If that happens and that transpose operator does not
adhere to the limitations then it will result in the following error:

CPU performance estimation for "Transpose" not implemented.
...
CPU operations are not supported for GraphAPI input

Args:
node: The FX node representing the view_copy operator.
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
node (fx.Node): FX node for ``view_copy`` or ``select``.

Returns:
False if the operator is not support and True if it is supported.
bool: False if rejected by constraints; otherwise, True.

"""
# Select decomposes into squeeze, which in turn becomes a view. Therefore,
# perform the same check on select operators as view operators.
Expand Down Expand Up @@ -279,14 +370,40 @@ def is_node_supported(


class EthosU55TransposeCheck(OperatorSupportBase):
"""Validate permute nodes against U55 reshape/transpose limits.

Applies dtype- and rank-specific constraints to permutations. Tests both
NCHW and NHWC interpretations for rank-3/4 shapes since dim order is unknown
at partition time.

Attributes:
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.

"""

def __init__(self, reporter: WhyNoPartitionReporter):
"""Initialize the check with a reporter.

Args:
reporter (WhyNoPartitionReporter): Reporter for rejection reasons.

"""
super().__init__()
self.reporter = reporter

def _pad_to_rank_4(
self, shape: shape_t, permutation: list[int]
) -> tuple[shape_t, shape_t]:
"""Pad shape/permutation to rank 4 by prepending ones/indices.

Args:
shape (list[int]): Original shape.
permutation (list[int]): Original permutation indices.

Returns:
tuple[list[int], list[int]]: Padded shape and permutation.

"""
diff = 4 - len(shape)
padded_shape = [1] * diff + shape
for i in range(len(permutation)):
Expand All @@ -295,6 +412,15 @@ def _pad_to_rank_4(
return padded_shape, padded_permutation

def axes_product(self, nhwc_shape: shape_t) -> int:
"""Return the product of all axes in ``nhwc_shape``.

Args:
nhwc_shape (list[int]): Shape in NHWC order.

Returns:
int: Product of the axis sizes.

"""
product = 1
for axes in nhwc_shape:
product *= axes
Expand All @@ -303,7 +429,7 @@ def axes_product(self, nhwc_shape: shape_t) -> int:
def _permute_constraint_i8_i16(
self, nhwc_shape: list[int], permutation: list[int]
) -> bool:
"""Returns True if the constraints are ok."""
"""Return True if permutation meets i8/i16 constraints."""
N, H, W, C = nhwc_shape
match permutation:
case (0, 1, 2, 3): # NHWC -> NHWC
Expand All @@ -316,7 +442,7 @@ def _permute_constraint_i8_i16(
def _permute_constraint_i32(
self, nhwc_shape: list[int], permutation: list[int]
) -> bool:
"""Returns True if the constraints are ok."""
"""Return True if permutation meets i32 constraints."""
N, H, W, C = nhwc_shape
match permutation:
case (0, 1, 2, 3): # NHWC -> NHWC
Expand All @@ -329,6 +455,7 @@ def _permute_constraint_i32(
return False

def _permute_constraint(self, shape, permutation, dtype):
"""Return True if permutation meets dtype-specific constraints."""
if dtype in (torch.int8, torch.int16):
return self._permute_constraint_i8_i16(shape, permutation)
if dtype == torch.int32:
Expand All @@ -338,7 +465,19 @@ def _permute_constraint(self, shape, permutation, dtype):
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
"""Return True if a permute node satisfies U55 constraints.

Tests both NCHW and NHWC interpretations for rank-3/4 shapes, and
applies dtype-specific limits to shapes and permutations.

Args:
submodules (typing.Mapping[str, torch.nn.Module]): Exported modules.
node (fx.Node): FX node to check.

Returns:
bool: True if supported; otherwise, False.

"""
if not node.target == exir_ops.edge.aten.permute_copy.default:
return True

Expand Down
Loading