Skip to content

Commit 8115da4

Browse files
authored
Arm backend: Initial int16 extension (#13318)
### Summary Add TOSA extension support for tests and add initial support for int16. ### Test plan Tested through unit tests in backends/arm. Signed-off-by: Per Åstrand <[email protected]>
1 parent c4c568a commit 8115da4

File tree

4 files changed

+56
-18
lines changed

4 files changed

+56
-18
lines changed

backends/arm/arm_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def ethosu_compile_spec(
128128
self.compiler_flags.append("--output-format=raw")
129129
self.compiler_flags.append("--debug-force-regor")
130130

131-
base_tosa_version = "TOSA-1.0+INT"
131+
base_tosa_version = "TOSA-1.0+INT+int16"
132132
if "u55" in target:
133133
# Add the Ethos-U55 extension marker
134134
base_tosa_version += "+u55"

backends/arm/test/ops/test_sigmoid_16bit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def get_16bit_sigmoid_quantizer(u55_config=False):
4141
tosa_version = conftest.get_option("tosa_version")
4242
tosa_profiles = {
4343
"1.0": TosaSpecification.create_from_string(
44-
"TOSA-1.0+INT" + ("+u55" if u55_config else "")
44+
"TOSA-1.0+INT+int16" + ("+u55" if u55_config else "")
4545
),
4646
}
4747

@@ -94,6 +94,7 @@ def test_sigmoid_tosa_INT(test_data):
9494
Sigmoid.aten_op,
9595
Sigmoid.exir_op,
9696
qtol=1,
97+
tosa_extensions=["int16"],
9798
)
9899
pipeline.change_args("quantize", get_16bit_sigmoid_quantizer())
99100
pipeline.run()
@@ -114,7 +115,9 @@ def test_sigmoid_tosa_INT_add_sigmoid(test_data):
114115
Sigmoid.aten_op,
115116
Sigmoid.exir_op,
116117
qtol=1,
118+
tosa_extensions=["int16"],
117119
)
120+
pipeline.change_args("quantize", get_16bit_sigmoid_quantizer())
118121
pipeline.run()
119122

120123

@@ -154,6 +157,7 @@ def test_sigmoid_u55_INT_add_sigmoid(test_data):
154157
n_expected_delegates=1,
155158
quantize=True,
156159
u55_subset=True,
160+
tosa_extensions=["int16"],
157161
)
158162
pipeline.change_args("quantize", get_16bit_sigmoid_quantizer(True))
159163
pipeline.run()

backends/arm/test/ops/test_sigmoid_32bit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def get_32bit_sigmoid_quantizer(u55_config=False):
5757
tosa_version = conftest.get_option("tosa_version")
5858
tosa_profiles = {
5959
"1.0": TosaSpecification.create_from_string(
60-
"TOSA-1.0+INT" + ("+u55" if u55_config else "")
60+
"TOSA-1.0+INT+int16" + ("+u55" if u55_config else "")
6161
),
6262
}
6363

@@ -110,6 +110,7 @@ def test_sigmoid_tosa_INT(test_data):
110110
Sigmoid.aten_op,
111111
Sigmoid.exir_op,
112112
qtol=1,
113+
tosa_extensions=["int16"],
113114
)
114115
pipeline.change_args("quantize", get_32bit_sigmoid_quantizer())
115116
pipeline.run()
@@ -123,6 +124,7 @@ def test_sigmoid_tosa_INT_add_sigmoid(test_data):
123124
Sigmoid.aten_op,
124125
Sigmoid.exir_op,
125126
qtol=1,
127+
tosa_extensions=["int16"],
126128
)
127129
pipeline.change_args("quantize", get_32bit_sigmoid_quantizer())
128130
pipeline.run()
@@ -136,6 +138,7 @@ def test_sigmoid_u55_INT(test_data):
136138
{Sigmoid.exir_op: 1},
137139
quantize=True,
138140
u55_subset=True,
141+
tosa_extensions=["int16"],
139142
)
140143
pipeline.change_args("quantize", get_32bit_sigmoid_quantizer(True))
141144
pipeline.run()
@@ -150,6 +153,7 @@ def test_sigmoid_u55_INT_add_sigmoid(test_data):
150153
n_expected_delegates=1,
151154
quantize=True,
152155
u55_subset=True,
156+
tosa_extensions=["int16"],
153157
)
154158
pipeline.change_args("quantize", get_32bit_sigmoid_quantizer(True))
155159
pipeline.run()

backends/arm/test/tester/test_pipeline.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -306,9 +306,14 @@ def __init__(
306306
rtol: float = 1e-03,
307307
qtol: int = 1,
308308
dynamic_shapes: Optional[Tuple[Any]] = None,
309+
tosa_extensions: Optional[List[str]] = None,
309310
):
311+
if tosa_extensions is None:
312+
tosa_extensions = []
310313
tosa_profiles = {
311-
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"),
314+
"1.0": TosaSpecification.create_from_string(
315+
"TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions])
316+
),
312317
}
313318
tosa_version = conftest.get_option("tosa_version")
314319

@@ -406,9 +411,14 @@ def __init__(
406411
transform_passes: Optional[
407412
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
408413
] = None,
414+
tosa_extensions: Optional[List[str]] = None,
409415
):
416+
if tosa_extensions is None:
417+
tosa_extensions = []
410418
tosa_profiles = {
411-
"1.0": TosaSpecification.create_from_string("TOSA-1.0+FP"),
419+
"1.0": TosaSpecification.create_from_string(
420+
"TOSA-1.0+FP" + "".join([f"+{ext}" for ext in tosa_extensions])
421+
),
412422
}
413423
tosa_version = conftest.get_option("tosa_version")
414424

@@ -655,10 +665,15 @@ def __init__(
655665
pass_functions: Optional[List[Callable]] = None,
656666
passes_with_exported_program: Optional[List[Type[ExportPass]]] = None,
657667
custom_path: str = None,
668+
tosa_extensions: Optional[List[str]] = None,
658669
):
670+
if tosa_extensions is None:
671+
tosa_extensions = []
659672
tosa_profiles = {
660673
"1.0": TosaSpecification.create_from_string(
661-
"TOSA-1.0+" + ("INT" if quantize else "FP")
674+
"TOSA-1.0+"
675+
+ ("INT" if quantize else "FP")
676+
+ "".join([f"+{ext}" for ext in tosa_extensions]),
662677
),
663678
}
664679
tosa_version = conftest.get_option("tosa_version")
@@ -721,9 +736,14 @@ def __init__(
721736
module: torch.nn.Module,
722737
test_data: T,
723738
custom_path: str = None,
739+
tosa_extensions: Optional[List[str]] = None,
724740
):
741+
if tosa_extensions is None:
742+
tosa_extensions = []
725743
tosa_profiles = {
726-
"1.0": TosaSpecification.create_from_string("TOSA-1.0+INT"),
744+
"1.0": TosaSpecification.create_from_string(
745+
"TOSA-1.0+INT" + "".join([f"+{ext}" for ext in tosa_extensions]),
746+
),
727747
}
728748
tosa_version = conftest.get_option("tosa_version")
729749

@@ -779,18 +799,23 @@ def __init__(
779799
custom_path: str = None,
780800
quantize: Optional[bool] = False,
781801
u55_subset: Optional[bool] = False,
802+
tosa_extensions: Optional[List[str]] = None,
782803
):
804+
if tosa_extensions is None:
805+
tosa_extensions = []
783806
tosa_profiles = {
784-
"1.0": "TOSA-1.0+" + ("INT" if quantize else "FP"),
807+
"1.0": TosaSpecification.create_from_string(
808+
"TOSA-1.0+"
809+
+ ("INT" if quantize else "FP")
810+
+ ("+u55" if u55_subset and quantize else "")
811+
+ "".join([f"+{ext}" for ext in tosa_extensions]),
812+
),
785813
}
786-
tosa_version = tosa_profiles[conftest.get_option("tosa_version")]
814+
tosa_version = conftest.get_option("tosa_version")
787815

788-
if u55_subset and quantize:
789-
tosa_version = f"{tosa_version}+u55"
816+
tosa_spec = tosa_profiles[tosa_version]
790817

791-
compile_spec = common.get_tosa_compile_spec(
792-
tosa_version, custom_path=custom_path
793-
)
818+
compile_spec = common.get_tosa_compile_spec(tosa_spec, custom_path=custom_path)
794819
super().__init__(
795820
module,
796821
test_data,
@@ -799,7 +824,7 @@ def __init__(
799824
[],
800825
)
801826

802-
if "INT" in tosa_version:
827+
if tosa_spec.support_integer():
803828
self.add_stage(self.tester.quantize, pos=0)
804829

805830
self.change_args("check_not.exir", [])
@@ -855,11 +880,16 @@ def __init__(
855880
transform_passes: Optional[
856881
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
857882
] = None,
883+
tosa_extensions: Optional[List[str]] = None,
858884
):
859885

860-
tosa_profile = TosaSpecification.create_from_string(tosa_version)
886+
if tosa_extensions is None:
887+
tosa_extensions = []
888+
tosa_spec = TosaSpecification.create_from_string(
889+
tosa_version + "".join([f"+{ext}" for ext in tosa_extensions])
890+
)
861891
compile_spec = common.get_vgf_compile_spec(
862-
tosa_profile, compiler_flags=vgf_compiler_flags, custom_path=custom_path
892+
tosa_spec, compiler_flags=vgf_compiler_flags, custom_path=custom_path
863893
)
864894

865895
super().__init__(
@@ -873,7 +903,7 @@ def __init__(
873903
transform_passes=transform_passes,
874904
)
875905

876-
if "INT" in tosa_version:
906+
if tosa_spec.support_integer():
877907
quantizer = VgfQuantizer(compile_spec)
878908
quantization_config = get_symmetric_quantization_config(
879909
is_per_channel=per_channel_quantization

0 commit comments

Comments
 (0)