Skip to content

Commit 8b9662c

Browse files
perkeyprocedure
authored andcommitted
Arm backend: Add support for TOSA 1.0 serializer (pytorch#10135)
Adapt serialization and TOSA graph handling to be able to handle 1.0. Also install TOSA pip package for 1.0 alongside 0.80. ### Test plan Validate that 0.80 TOSA version test still work with the 1.0 package installed. Signed-off-by: Per Åstrand <[email protected]>
1 parent eca61ff commit 8b9662c

15 files changed

+118
-44
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
166166

167167
return self._transform(exported_program.graph_module)
168168

169+
def _tosa_1_0_int_quantized_pipeline(self, exported_program: ExportedProgram):
170+
return self._tosa_080_BI_pipeline(exported_program)
171+
172+
def _tosa_1_0_fp_pipeline(self, exported_program: ExportedProgram):
173+
return self._tosa_080_MI_pipeline(exported_program)
174+
169175
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
170176
"""Apply passes before transforming program to backend"""
171177
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
172178
return self._tosa_080_BI_pipeline(exported_program)
173179
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
174180
return self._tosa_080_MI_pipeline(exported_program)
181+
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
182+
return self._tosa_1_0_fp_pipeline(exported_program)
183+
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
184+
return self._tosa_1_0_int_quantized_pipeline(exported_program)
175185
else:
176186
raise NotImplementedError(
177187
f"No pass pipeline implemented for {self.tosa_spec=}"

backends/arm/operator_support/convolution_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck):
2222
tosa_specs = [
2323
TosaSpecification.create_from_string("TOSA-0.80+BI"),
2424
TosaSpecification.create_from_string("TOSA-0.80+MI"),
25+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
26+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2527
]
2628

2729
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/minmax_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class MinMaxSupported(SupportedTOSAOperatorCheck):
2222
# TODO : "MLETORCH-718 : Quantization of indices in arm_quantizer"
2323
tosa_specs = [
2424
TosaSpecification.create_from_string("TOSA-0.80+MI"),
25+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2526
]
2627

2728
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/pool_2d_support.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
4141
tosa_specs = [
4242
TosaSpecification.create_from_string("TOSA-0.80+BI"),
4343
TosaSpecification.create_from_string("TOSA-0.80+MI"),
44+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
45+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
4446
]
4547

4648
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
@@ -94,6 +96,8 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
9496
tosa_specs = [
9597
TosaSpecification.create_from_string("TOSA-0.80+BI"),
9698
TosaSpecification.create_from_string("TOSA-0.80+MI"),
99+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
100+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
97101
]
98102

99103
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/reduce_sum_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class SumSupported(SupportedTOSAOperatorCheck):
2121
tosa_specs = [
2222
TosaSpecification.create_from_string("TOSA-0.80+BI"),
2323
TosaSpecification.create_from_string("TOSA-0.80+MI"),
24+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
25+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2426
]
2527

2628
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/right_shift_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
2929
tosa_specs = [
3030
TosaSpecification.create_from_string("TOSA-0.80+BI"),
3131
TosaSpecification.create_from_string("TOSA-0.80+MI"),
32+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
33+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3234
]
3335

3436
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/slice_copy_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class SliceCopySupported(SupportedTOSAOperatorCheck):
2525
tosa_specs = [
2626
TosaSpecification.create_from_string("TOSA-0.80+BI"),
2727
TosaSpecification.create_from_string("TOSA-0.80+MI"),
28+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
29+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2830
]
2931

3032
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: # type: ignore[override, misc]

backends/arm/operator_support/to_copy_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
3030
tosa_specs = [
3131
TosaSpecification.create_from_string("TOSA-0.80+BI"),
3232
TosaSpecification.create_from_string("TOSA-0.80+MI"),
33+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
34+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3335
]
3436

3537
SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def is_node_tosa_supported(
6666
_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
6767
TosaSpecification.create_from_string("TOSA-0.80+BI"): [],
6868
TosaSpecification.create_from_string("TOSA-0.80+MI"): [],
69+
TosaSpecification.create_from_string("TOSA-1.0+INT"): [],
70+
TosaSpecification.create_from_string("TOSA-1.0+FP"): [],
6971
}
7072

7173

backends/arm/operators/node_visitor.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55

66
# pyre-unsafe
77

8-
from typing import Dict, List
8+
from typing import Any, Dict, List
99

1010
import torch
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1312
from executorch.backends.arm.tosa_mapping import TosaArg
1413
from executorch.backends.arm.tosa_specification import TosaSpecification
1514
from torch.export import ExportedProgram
@@ -25,19 +24,26 @@ class NodeVisitor:
2524
# a specific TOSA version.
2625
# When all node_visitors has been refactored to target a specific
2726
# version, this list should be removed.
28-
tosa_specs = [
27+
tosa_specs_1_00 = [
28+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
29+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
30+
]
31+
32+
tosa_specs_0_80 = [
2933
TosaSpecification.create_from_string("TOSA-0.80+BI"),
3034
TosaSpecification.create_from_string("TOSA-0.80+MI"),
3135
]
3236

37+
tosa_specs = tosa_specs_0_80 + tosa_specs_1_00
38+
3339
def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
3440
self._exported_program = exported_program
3541
self.tosa_spec = tosa_spec
3642

3743
def define_node(
3844
self,
3945
node: torch.fx.Node,
40-
tosa_graph: ts.TosaSerializer,
46+
tosa_graph: Any,
4147
inputs: List[TosaArg],
4248
output: TosaArg,
4349
) -> None:
@@ -48,6 +54,8 @@ def define_node(
4854
_node_visitor_dicts: Dict[TosaSpecification, Dict] = {
4955
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
5056
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
57+
TosaSpecification.create_from_string("TOSA-1.0+INT"): {},
58+
TosaSpecification.create_from_string("TOSA-1.0+FP"): {},
5159
}
5260

5361

0 commit comments

Comments
 (0)