Skip to content

Commit d98a3e9

Browse files
committed
Arm backend: Add extension support for test pipelines
Adds the possibility to add extension to the TOSA specification for the test pipeline. Signed-off-by: Per Åstrand <[email protected]> Change-Id: I64d8dcce3ea1bf6aa76230396690fd007974f792
1 parent baed71b commit d98a3e9

File tree

1 file changed

+45
-15
lines changed

1 file changed

+45
-15
lines changed

backends/arm/test/tester/test_pipeline.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,14 @@ def __init__(
306306
rtol: float = 1e-03,
307307
qtol: int = 1,
308308
dynamic_shapes: Optional[Tuple[Any]] = None,
309+
tosa_extensions: Optional[List[str]] = None,
309310
):
311+
if tosa_extensions is None:
312+
tosa_extensions = []
310313
tosa_profiles = {
311-
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"),
314+
"1.0": TosaSpecification.create_from_string(
315+
"TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions])
316+
),
312317
}
313318
tosa_version = conftest.get_option("tosa_version")
314319

@@ -406,9 +411,14 @@ def __init__(
406411
transform_passes: Optional[
407412
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
408413
] = None,
414+
tosa_extensions: Optional[List[str]] = None,
409415
):
416+
if tosa_extensions is None:
417+
tosa_extensions = []
410418
tosa_profiles = {
411-
"1.0": TosaSpecification.create_from_string("TOSA-1.0+FP"),
419+
"1.0": TosaSpecification.create_from_string(
420+
"TOSA-1.0+FP" + "".join([f"+{ext}" for ext in tosa_extensions])
421+
),
412422
}
413423
tosa_version = conftest.get_option("tosa_version")
414424

@@ -655,10 +665,15 @@ def __init__(
655665
pass_functions: Optional[List[Callable]] = None,
656666
passes_with_exported_program: Optional[List[Type[ExportPass]]] = None,
657667
custom_path: str = None,
668+
tosa_extensions: Optional[List[str]] = None,
658669
):
670+
if tosa_extensions is None:
671+
tosa_extensions = []
659672
tosa_profiles = {
660673
"1.0": TosaSpecification.create_from_string(
661-
"TOSA-1.0+" + ("INT" if quantize else "FP")
674+
"TOSA-1.0+"
675+
+ ("INT" if quantize else "FP")
676+
+ "".join([f"+{ext}" for ext in tosa_extensions]),
662677
),
663678
}
664679
tosa_version = conftest.get_option("tosa_version")
@@ -721,9 +736,14 @@ def __init__(
721736
module: torch.nn.Module,
722737
test_data: T,
723738
custom_path: str = None,
739+
tosa_extensions: Optional[List[str]] = None,
724740
):
741+
if tosa_extensions is None:
742+
tosa_extensions = []
725743
tosa_profiles = {
726-
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"),
744+
"1.0": TosaSpecification.create_from_string(
745+
"TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions]),
746+
),
727747
}
728748
tosa_version = conftest.get_option("tosa_version")
729749

@@ -779,18 +799,23 @@ def __init__(
779799
custom_path: str = None,
780800
quantize: Optional[bool] = False,
781801
u55_subset: Optional[bool] = False,
802+
tosa_extensions: Optional[List[str]] = None,
782803
):
804+
if tosa_extensions is None:
805+
tosa_extensions = []
783806
tosa_profiles = {
784-
"1.0": "TOSA-1.0+" + ("INT" if quantize else "FP"),
807+
"1.0": TosaSpecification.create_from_string(
808+
"TOSA-1.0+"
809+
+ ("INT" if quantize else "FP")
810+
+ ("+u55" if u55_subset and quantize else "")
811+
+ "".join([f"+{ext}" for ext in tosa_extensions]),
812+
),
785813
}
786-
tosa_version = tosa_profiles[conftest.get_option("tosa_version")]
814+
tosa_version = conftest.get_option("tosa_version")
787815

788-
if u55_subset and quantize:
789-
tosa_version = f"{tosa_version}+u55"
816+
tosa_spec = tosa_profiles[tosa_version]
790817

791-
compile_spec = common.get_tosa_compile_spec(
792-
tosa_version, custom_path=custom_path
793-
)
818+
compile_spec = common.get_tosa_compile_spec(tosa_spec, custom_path=custom_path)
794819
super().__init__(
795820
module,
796821
test_data,
@@ -799,7 +824,7 @@ def __init__(
799824
[],
800825
)
801826

802-
if "INT" in tosa_version:
827+
if tosa_spec.support_integer():
803828
self.add_stage(self.tester.quantize, pos=0)
804829

805830
self.change_args("check_not.exir", [])
@@ -855,11 +880,16 @@ def __init__(
855880
transform_passes: Optional[
856881
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
857882
] = None,
883+
tosa_extensions: Optional[List[str]] = None,
858884
):
859885

860-
tosa_profile = TosaSpecification.create_from_string(tosa_version)
886+
if tosa_extensions is None:
887+
tosa_extensions = []
888+
tosa_spec = TosaSpecification.create_from_string(
889+
tosa_version + "".join([f"+{ext}" for ext in tosa_extensions])
890+
)
861891
compile_spec = common.get_vgf_compile_spec(
862-
tosa_profile, compiler_flags=vgf_compiler_flags, custom_path=custom_path
892+
tosa_spec, compiler_flags=vgf_compiler_flags, custom_path=custom_path
863893
)
864894

865895
super().__init__(
@@ -873,7 +903,7 @@ def __init__(
873903
transform_passes=transform_passes,
874904
)
875905

876-
if "INT" in tosa_version:
906+
if tosa_spec.support_integer():
877907
quantizer = VgfQuantizer(compile_spec)
878908
quantization_config = get_symmetric_quantization_config(
879909
is_per_channel=per_channel_quantization

0 commit comments

Comments
 (0)