diff --git a/backends/arm/arm_backend.py b/backends/arm/arm_backend.py index e2335c07b87..173cba5b508 100644 --- a/backends/arm/arm_backend.py +++ b/backends/arm/arm_backend.py @@ -128,7 +128,7 @@ def ethosu_compile_spec( self.compiler_flags.append("--output-format=raw") self.compiler_flags.append("--debug-force-regor") - base_tosa_version = "TOSA-1.0+INT" + base_tosa_version = "TOSA-1.0+INT+int16" if "u55" in target: # Add the Ethos-U55 extension marker base_tosa_version += "+u55" diff --git a/backends/arm/test/ops/test_sigmoid_16bit.py b/backends/arm/test/ops/test_sigmoid_16bit.py index a41681675ce..3d70881a3f0 100644 --- a/backends/arm/test/ops/test_sigmoid_16bit.py +++ b/backends/arm/test/ops/test_sigmoid_16bit.py @@ -41,7 +41,7 @@ def get_16bit_sigmoid_quantizer(u55_config=False): tosa_version = conftest.get_option("tosa_version") tosa_profiles = { "1.0": TosaSpecification.create_from_string( - "TOSA-1.0+INT" + ("+u55" if u55_config else "") + "TOSA-1.0+INT+int16" + ("+u55" if u55_config else "") ), } @@ -94,6 +94,7 @@ def test_sigmoid_tosa_INT(test_data): Sigmoid.aten_op, Sigmoid.exir_op, qtol=1, + tosa_extensions=["int16"], ) pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) pipeline.run() @@ -114,7 +115,9 @@ def test_sigmoid_tosa_INT_add_sigmoid(test_data): Sigmoid.aten_op, Sigmoid.exir_op, qtol=1, + tosa_extensions=["int16"], ) + pipeline.change_args("quantize", get_16bit_sigmoid_quantizer()) pipeline.run() @@ -154,6 +157,7 @@ def test_sigmoid_u55_INT_add_sigmoid(test_data): n_expected_delegates=1, quantize=True, u55_subset=True, + tosa_extensions=["int16"], ) pipeline.change_args("quantize", get_16bit_sigmoid_quantizer(True)) pipeline.run() diff --git a/backends/arm/test/ops/test_sigmoid_32bit.py b/backends/arm/test/ops/test_sigmoid_32bit.py index 7d2e649bcd8..553a852b245 100644 --- a/backends/arm/test/ops/test_sigmoid_32bit.py +++ b/backends/arm/test/ops/test_sigmoid_32bit.py @@ -57,7 +57,7 @@ def get_32bit_sigmoid_quantizer(u55_config=False): tosa_version = conftest.get_option("tosa_version") tosa_profiles = { "1.0": TosaSpecification.create_from_string( - "TOSA-1.0+INT" + ("+u55" if u55_config else "") + "TOSA-1.0+INT+int16" + ("+u55" if u55_config else "") ), } @@ -110,6 +110,7 @@ def test_sigmoid_tosa_INT(test_data): Sigmoid.aten_op, Sigmoid.exir_op, qtol=1, + tosa_extensions=["int16"], ) pipeline.change_args("quantize", get_32bit_sigmoid_quantizer()) pipeline.run() @@ -123,6 +124,7 @@ def test_sigmoid_tosa_INT_add_sigmoid(test_data): Sigmoid.aten_op, Sigmoid.exir_op, qtol=1, + tosa_extensions=["int16"], ) pipeline.change_args("quantize", get_32bit_sigmoid_quantizer()) pipeline.run() @@ -136,6 +138,7 @@ def test_sigmoid_u55_INT(test_data): {Sigmoid.exir_op: 1}, quantize=True, u55_subset=True, + tosa_extensions=["int16"], ) pipeline.change_args("quantize", get_32bit_sigmoid_quantizer(True)) pipeline.run() @@ -150,6 +153,7 @@ def test_sigmoid_u55_INT_add_sigmoid(test_data): n_expected_delegates=1, quantize=True, u55_subset=True, + tosa_extensions=["int16"], ) pipeline.change_args("quantize", get_32bit_sigmoid_quantizer(True)) pipeline.run() diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index fb9f05444e5..cbe3f5f613d 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -306,9 +306,14 @@ def __init__( rtol: float = 1e-03, qtol: int = 1, dynamic_shapes: Optional[Tuple[Any]] = None, + tosa_extensions: Optional[List[str]] = None, ): + if tosa_extensions is None: + tosa_extensions = [] tosa_profiles = { - "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"), + "1.0": TosaSpecification.create_from_string( + "TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions]) + ), } tosa_version = conftest.get_option("tosa_version") @@ -406,9 +411,14 @@ def __init__( transform_passes: Optional[ Union[Sequence[PassType], Dict[str, Sequence[PassType]]] ] = None, + tosa_extensions: Optional[List[str]] = None, ): + if tosa_extensions is None: + tosa_extensions = [] tosa_profiles = { - "1.0": TosaSpecification.create_from_string("TOSA-1.0+FP"), + "1.0": TosaSpecification.create_from_string( + "TOSA-1.0+FP" + "".join([f"+{ext}" for ext in tosa_extensions]) + ), } tosa_version = conftest.get_option("tosa_version") @@ -655,10 +665,15 @@ def __init__( pass_functions: Optional[List[Callable]] = None, passes_with_exported_program: Optional[List[Type[ExportPass]]] = None, custom_path: str = None, + tosa_extensions: Optional[List[str]] = None, ): + if tosa_extensions is None: + tosa_extensions = [] tosa_profiles = { "1.0": TosaSpecification.create_from_string( - "TOSA-1.0+" + ("INT" if quantize else "FP") + "TOSA-1.0+" + + ("INT" if quantize else "FP") + + "".join([f"+{ext}" for ext in tosa_extensions]), ), } tosa_version = conftest.get_option("tosa_version") @@ -721,9 +736,14 @@ def __init__( module: torch.nn.Module, test_data: T, custom_path: str = None, + tosa_extensions: Optional[List[str]] = None, ): + if tosa_extensions is None: + tosa_extensions = [] tosa_profiles = { - "1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"), + "1.0": TosaSpecification.create_from_string( + "TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions]), + ), } tosa_version = conftest.get_option("tosa_version") @@ -779,18 +799,23 @@ def __init__( custom_path: str = None, quantize: Optional[bool] = False, u55_subset: Optional[bool] = False, + tosa_extensions: Optional[List[str]] = None, ): + if tosa_extensions is None: + tosa_extensions = [] tosa_profiles = { - "1.0": "TOSA-1.0+" + ("INT" if quantize else "FP"), + "1.0": TosaSpecification.create_from_string( + "TOSA-1.0+" + + ("INT" if quantize else "FP") + + ("+u55" if u55_subset and quantize else "") + + "".join([f"+{ext}" for ext in tosa_extensions]), + ), } - tosa_version = tosa_profiles[conftest.get_option("tosa_version")] + tosa_version = conftest.get_option("tosa_version") - if u55_subset and quantize: - tosa_version = f"{tosa_version}+u55" + tosa_spec = tosa_profiles[tosa_version] - compile_spec = common.get_tosa_compile_spec( - tosa_version, custom_path=custom_path - ) + compile_spec = common.get_tosa_compile_spec(tosa_spec, custom_path=custom_path) super().__init__( module, test_data, @@ -799,7 +824,7 @@ def __init__( [], ) - if "INT" in tosa_version: + if tosa_spec.support_integer(): self.add_stage(self.tester.quantize, pos=0) self.change_args("check_not.exir", []) @@ -855,11 +880,16 @@ def __init__( transform_passes: Optional[ Union[Sequence[PassType], Dict[str, Sequence[PassType]]] ] = None, + tosa_extensions: Optional[List[str]] = None, ): - tosa_profile = TosaSpecification.create_from_string(tosa_version) + if tosa_extensions is None: + tosa_extensions = [] + tosa_spec = TosaSpecification.create_from_string( + tosa_version + "".join([f"+{ext}" for ext in tosa_extensions]) + ) compile_spec = common.get_vgf_compile_spec( - tosa_profile, compiler_flags=vgf_compiler_flags, custom_path=custom_path + tosa_spec, compiler_flags=vgf_compiler_flags, custom_path=custom_path ) super().__init__( @@ -873,7 +903,7 @@ def __init__( transform_passes=transform_passes, ) - if "INT" in tosa_version: + if tosa_spec.support_integer(): quantizer = VgfQuantizer(compile_spec) quantization_config = get_symmetric_quantization_config( is_per_channel=per_channel_quantization