Skip to content

Commit 6531b4a

Browse files
Arm backend: Add U55 operator check for cast (#14577)
U55 does not support casting from int32 or any cast involving booleans. This patch introduces a new U55 operator check that will reject any such casts. Note that bool->bool and int32->int32 are considered ok as these will not result in any cast operation. Signed-off-by: Oscar Andersson <[email protected]>
1 parent 4622edb commit 6531b4a

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

backends/arm/operator_support/ethos_u55_support.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,3 +384,63 @@ def is_node_supported(
384384
return False
385385

386386
return True
387+
388+
389+
class EthosU55CastCheck(OperatorSupportBase):
    """Filter out dtype casts that cannot run on U55.

    A cast is rejected when its source dtype is INT32 or BOOL, or when its
    destination dtype is BOOL. A cast whose source and destination dtypes are
    equal is accepted, since it does not produce an actual cast operation.

    Attributes:
        reporter (WhyNoPartitionReporter): Receives the reason for every
            rejected node.

    """

    # Only nodes with one of these targets are inspected; all others pass.
    targets = [
        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
    ]

    def __init__(self, reporter: WhyNoPartitionReporter):
        """Store the reporter used to log rejections.

        Args:
            reporter (WhyNoPartitionReporter): Receives the reason for every
                rejected node.

        """
        super().__init__()
        self.reporter = reporter

    def is_node_supported(
        self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
    ) -> bool:
        """Decide whether ``node`` is a cast that U55 can execute.

        Args:
            submodules (typing.Mapping[str, torch.nn.Module]): Exported
                modules; not consulted by this check.
            node (fx.Node): FX node under inspection.

        Returns:
            bool: False if the node is a rejected cast (the reason is sent to
                the reporter); True otherwise, including for nodes whose
                target is not handled by this check.

        """
        if node.target not in self.targets:
            return True

        source_node = node.all_input_nodes[0]
        input_dtype = get_first_fake_tensor(source_node).dtype
        output_dtype = get_first_fake_tensor(node).dtype

        # Same dtype on both sides: nothing is actually cast.
        if input_dtype == output_dtype:
            return True

        # The source dtype is examined first, so a cast that is invalid on
        # both ends is reported with the "Casting from" message.
        reason = None
        if input_dtype in (torch.bool, torch.int32):
            reason = f"Casting from {input_dtype} is not supported on U55."
        elif output_dtype == torch.bool:
            reason = f"Casting to {output_dtype} is not supported on U55."

        if reason is None:
            return True

        self.reporter.report_reject(node, reason)
        return False

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from executorch.backends.arm._passes.insert_table_ops import TableOps
2222
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
2323
from executorch.backends.arm.operator_support.ethos_u55_support import (
24+
EthosU55CastCheck,
2425
EthosU55DtypeSupport,
2526
EthosU55NotSupported,
2627
EthosU55TransposeCheck,
@@ -141,6 +142,7 @@ def tosa_support_factory(
141142
negative_checks.append(EthosU55DtypeSupport(reporter))
142143
negative_checks.append(EthosU55TransposeCheck(reporter))
143144
negative_checks.append(EthosU55ViewCheck(reporter))
145+
negative_checks.append(EthosU55CastCheck(reporter))
144146

145147
return chain(
146148
reporter.wrap_check(

backends/arm/test/ops/test_to_copy.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,32 @@ def test_to_tosa_INT_not_delegated_REDUNDANT_CAST(test_data: Tuple):
244244
non_delegated_ops={}, # These are removed outside of the Arm backend so the graph is empty
245245
)
246246
pipeline.run()
247+
248+
249+
# Cast cases used by the U55 rejection test. Each entry is a zero-argument
# factory returning (input tensor, requested output dtype), so the tensors are
# only created when a test case actually runs.
_TO_COPY_DATA_INT_U55_REJECT = dict(
    # BOOL source dtype.
    rand_bool_int8=lambda: (
        torch.randint(low=0, high=2, size=(1, 2, 3, 4), dtype=torch.bool),
        torch.int8,
    ),
    # BOOL destination dtype.
    rand_int16_bool=lambda: (
        torch.randint(low=-1000, high=1000, size=(1, 2, 3, 4), dtype=torch.int16),
        torch.bool,
    ),
    # INT32 source dtype.
    rand_int32_int8=lambda: (
        torch.randint(low=-1000, high=1000, size=(1, 2, 3, 4), dtype=torch.int32),
        torch.int8,
    ),
)
263+
264+
265+
@common.parametrize("test_data", _TO_COPY_DATA_INT_U55_REJECT)
def test_to_u55_INT(test_data: Tuple):
    """Run each U55 reject-list cast through the not-supported pipeline.

    Args:
        test_data (Tuple): Zero-argument factory returning the input tensor
            and the dtype to cast it to.

    """
    tensor, target_dtype = test_data()
    OpNotSupportedPipeline[input_t1](
        Cast(target_dtype),
        (tensor,),
        u55_subset=True,
        quantize=True,
        # Per the original author's note: these ops are removed outside of
        # the Arm backend, so the graph is empty.
        non_delegated_ops={},
    ).run()

0 commit comments

Comments
 (0)