diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 20d4f41a273..4123d217e94 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -59,7 +59,7 @@ UnsqueezeScalarPlaceholdersPass, ) -from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.transforms.decompose_sdpa import ( DecomposeScaledDotProductAttention, ) @@ -92,7 +92,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul self.add_pass(ConvertMinMaxPass()) self.add_pass(ConvertAnyDefaultDimDimsPass()) self.add_pass(MatchWhereSelfDtypePass()) - if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: + if self.tosa_spec.is_U55_subset: self.add_pass(CastToInt32Pass()) self.add_pass(ReplaceScalarWithTensorArgPassTOSABI()) @@ -210,7 +210,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule): self.add_pass(DecomposeSqrtPass()) self.add_pass(DecomposeSiluPass()) - if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: + if self.tosa_spec.is_U55_subset: # Numerically stable softmax uses amax which is not supported on Ethos-U55 self.add_pass(DecomposeSoftmaxUnstablePass()) else: diff --git a/backends/arm/operator_support/convolution_support.py b/backends/arm/operator_support/convolution_support.py index 5b4fefdbf81..3e3149f3443 100644 --- a/backends/arm/operator_support/convolution_support.py +++ b/backends/arm/operator_support/convolution_support.py @@ -11,11 +11,8 @@ register_tosa_support_check, SupportedTOSAOperatorCheck, ) -from executorch.backends.arm.tosa_specification import ( - Tosa_0_80, - Tosa_1_00, - TosaSpecification, -) +from executorch.backends.arm.tosa_specification import TosaSpecification + from executorch.exir.dialects._ops import ops as exir_ops @@ -46,13 +43,10 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): return False # Hardware specific constraints - if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): - # TODO remove this once TOSA 1.0 support for u55 is added. - if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions: - return False - return True - else: + if tosa_spec.is_U55_subset: return self._is_node_supported_u55(node) + else: + return True def _is_node_supported_u55(self, node: fx.Node): """Hardware constraints for Ethos-U-55 case, Vela 4.2.0 (25.02 release)""" diff --git a/backends/arm/operator_support/pool_2d_support.py b/backends/arm/operator_support/pool_2d_support.py index f4ada36de80..753cd7c747b 100644 --- a/backends/arm/operator_support/pool_2d_support.py +++ b/backends/arm/operator_support/pool_2d_support.py @@ -11,7 +11,7 @@ register_tosa_support_check, SupportedTOSAOperatorCheck, ) -from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops @@ -46,7 +46,7 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck): ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): - if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): + if not tosa_spec.is_U55_subset: return True # U55 case, Vela 4.2.0 (25.02 release) @@ -104,7 +104,7 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck): ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): - if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): + if not tosa_spec.is_U55_subset: return True # U55 case, Vela 4.2.0 (25.02 release) diff --git a/backends/arm/operator_support/reduce_sum_support.py b/backends/arm/operator_support/reduce_sum_support.py index a50bcbceab7..4d0614d4b1a 100644 --- a/backends/arm/operator_support/reduce_sum_support.py +++ b/backends/arm/operator_support/reduce_sum_support.py @@ -10,7 +10,7 @@ register_tosa_support_check, SupportedTOSAOperatorCheck, ) -from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops @@ -26,7 +26,7 @@ class SumSupported(SupportedTOSAOperatorCheck): ] def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): - if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset): + if not tosa_spec.is_U55_subset: return True # U55 case, Vela 4.2.0 (25.02 release) diff --git a/backends/arm/operator_support/right_shift_support.py b/backends/arm/operator_support/right_shift_support.py index 49976b2346f..d18950a58a2 100644 --- a/backends/arm/operator_support/right_shift_support.py +++ b/backends/arm/operator_support/right_shift_support.py @@ -13,7 +13,7 @@ register_tosa_support_check, SupportedTOSAOperatorCheck, ) -from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir.dialects._ops import ops as exir_ops logger = logging.getLogger(__name__) @@ -36,6 +36,6 @@ class RightShiftSupported(SupportedTOSAOperatorCheck): def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification): # TODO MLETORCH-525 Remove warning - if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: + if tosa_spec.is_U55_subset: logging.warning(f"{node.target} may introduce one-off errors.") return True diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index c732c91a20a..547eafbfa8d 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -24,11 +24,7 @@ EthosU55NotSupported, EthosU55TransposeCheck, ) -from executorch.backends.arm.tosa_specification import ( - Tosa_0_80, - Tosa_1_00, - TosaSpecification, -) +from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops @@ -129,9 +125,7 @@ def tosa_support_factory( if not tosa_spec.support_float(): negative_checks.append(NeedsDecompositionCheck(reporter)) negative_checks.append(CheckProperQuantization(reporter)) - if (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset) or ( - isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions - ): + if tosa_spec.is_U55_subset: negative_checks.append(EthosU55NotSupported(reporter)) negative_checks.append(EthosU55DtypeSupport(reporter)) negative_checks.append(EthosU55TransposeCheck(reporter)) diff --git a/backends/arm/operators/op_rshift_tensor.py b/backends/arm/operators/op_rshift_tensor.py index e843f669a58..ece6debeab4 100644 --- a/backends/arm/operators/op_rshift_tensor.py +++ b/backends/arm/operators/op_rshift_tensor.py @@ -17,7 +17,6 @@ validate_num_inputs, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00 @register_node_visitor @@ -39,7 +38,7 @@ def define_node( attr = ts.TosaSerializerAttribute() round = False - if isinstance(self.tosa_spec, Tosa_0_80) and self.tosa_spec.is_U55_subset: + if self.tosa_spec.is_U55_subset: # U55 only supports INT32 and round == True # TODO MLETORCH-525 Emulate round == False with different decomposition round = True @@ -72,7 +71,7 @@ def define_node( attr = ts.TosaSerializerAttribute() round = False - if isinstance(self.tosa_spec, Tosa_1_00) and "u55" in self.tosa_spec.extensions: + if self.tosa_spec.is_U55_subset: # U55 only supports INT32 and round == True # TODO MLETORCH-525 Emulate round == False with different decomposition round = True diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 480497b4aee..13e2f80b5c5 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -293,7 +293,9 @@ def __init__( ) quant_stage = ( Quantize( - TOSAQuantizer(compile_spec).set_io(get_symmetric_quantization_config()), + TOSAQuantizer(tosa_profiles[tosa_version]).set_io( + get_symmetric_quantization_config() + ), get_symmetric_quantization_config(), ) if symmetric_io_quantization diff --git a/backends/arm/tosa_specification.py b/backends/arm/tosa_specification.py index 640361e059c..0cf5cfab74d 100644 --- a/backends/arm/tosa_specification.py +++ b/backends/arm/tosa_specification.py @@ -36,6 +36,7 @@ class TosaSpecification: """ version: Version + is_U55_subset: bool def support_integer(self) -> bool: """ @@ -49,9 +50,13 @@ def support_float(self) -> bool: """ raise NotImplementedError - def __init__(self, version: Version): + def __init__(self, version: Version, extras: List[str]): self.version = version + self.is_U55_subset = "u55" in extras + if self.is_U55_subset: + extras.remove("u55") + @staticmethod def create_from_string(repr: str) -> "TosaSpecification": """ @@ -85,11 +90,10 @@ def create_from_string(repr: str) -> "TosaSpecification": class Tosa_0_80(TosaSpecification): profile: str level_8k: bool - is_U55_subset: bool available_profiles = ["BI", "MI"] # MT is not defined def __init__(self, version: Version, extras: List[str]): - super().__init__(version) + super().__init__(version, extras) assert version >= Version("0.80") and version < Version("0.90") # Check that we only have one profile in the extensions list @@ -105,9 +109,6 @@ def __init__(self, version: Version, extras: List[str]): self.level_8k = "8k" in extras if self.level_8k: extras.remove("8k") - self.is_U55_subset = "u55" in extras - if self.is_U55_subset: - extras.remove("u55") if len(extras) > 0: raise ValueError(f"Unhandled extras found: {extras}") @@ -147,7 +148,7 @@ class Tosa_1_00(TosaSpecification): } def __init__(self, version: Version, extras: List[str]): - super().__init__(version) + super().__init__(version, extras) # Check that we have at least one profile in the extensions list if [e in Tosa_1_00.available_profiles for e in extras].count(True) == 0: @@ -194,6 +195,8 @@ def __repr__(self): extensions = self._get_extensions_string() if self.level_8k: extensions += "+8k" + if self.is_U55_subset: + extensions += "+u55" return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}" def __hash__(self) -> int: