diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py
index bdbbbfd1162..ece26ae4f81 100644
--- a/backends/arm/arm_backend.py
+++ b/backends/arm/arm_backend.py
@@ -10,12 +10,15 @@
 # backends. Converts via TOSA as an intermediate form supported by AoT and
 # JIT compiler flows.
 #
-
 from typing import List, Optional
 
-from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_specification import (  # type: ignore[import-not-found]
+    TosaSpecification,
+)
 
-from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.compile_spec_schema import (  # type: ignore[import-not-found]
+    CompileSpec,
+)
 
 
 class ArmCompileSpecBuilder:
@@ -28,6 +31,7 @@ def __init__(self):
 
     def vgf_compile_spec(
         self,
+        tosa_spec: Optional[TosaSpecification] = None,
         compiler_flags: Optional[str] = "",
     ) -> "ArmCompileSpecBuilder":
         """
@@ -40,7 +44,34 @@ def vgf_compile_spec(
         self.compiler_flags = [
             compiler_flags,
         ]
-        self.tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+MI")
+
+        if tosa_spec is None:
+            tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP")
+
+        tosa_version = tosa_spec.version  # type: ignore[attr-defined]
+        tosa_profiles = tosa_spec.profiles  # type: ignore[attr-defined]
+
+        if tosa_version.major != 1:
+            raise ValueError(
+                "Arm backend only supports converter-backend for TOSA version 1. "
+                f"Invalid TOSA version: {tosa_version}"
+            )
+
+        # Reject specs that carry neither of the supported profiles (FP, INT).
+        if not ("FP" in tosa_profiles or "INT" in tosa_profiles):
+            raise ValueError(
+                "Arm backend only supports converter-backend for FP or INT. "
+                f"Invalid TOSA profile: {tosa_profiles}"
+            )
+
+        if len(tosa_profiles) != 1:
+            raise ValueError(
+                "For now Arm backend only supports converter-backend for either FP or INT. "
+                f"Invalid TOSA profile: {tosa_profiles}"
+            )
+
+        self.tosa_spec = tosa_spec
+
         return self
 
     def ethosu_compile_spec(
diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py
index 5449ced09b9..148f9c1d477 100644
--- a/examples/arm/aot_arm_compiler.py
+++ b/examples/arm/aot_arm_compiler.py
@@ -385,6 +385,7 @@ def get_compile_spec(
     intermediates: Optional[str] = None,
     system_config: Optional[str] = None,
     memory_mode: Optional[str] = None,
+    quantize: bool = False,
 ) -> list[CompileSpec]:
     spec_builder = None
     if target.startswith("TOSA"):
@@ -401,7 +402,11 @@ def get_compile_spec(
             extra_flags="--verbose-operators --verbose-cycle-estimate",
         )
     elif "vgf" in target:
-        spec_builder = ArmCompileSpecBuilder().vgf_compile_spec()
+        if quantize:
+            tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
+        else:
+            tosa_spec = TosaSpecification.create_from_string("TOSA-1.0+FP")
+        spec_builder = ArmCompileSpecBuilder().vgf_compile_spec(tosa_spec)
 
     if intermediates is not None:
         spec_builder.dump_intermediate_artifacts_to(intermediates)
@@ -700,6 +705,7 @@ def to_edge_TOSA_delegate(
         args.intermediates,
         args.system_config,
         args.memory_mode,
+        args.quantize,
     )
 
     model_int8 = None
@@ -739,6 +745,7 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_
             args.intermediates,
             args.system_config,
             args.memory_mode,
+            args.quantize,
         )
         model, exported_program = quantize_model(
             args, model, example_inputs, compile_spec