Skip to content

Commit ef8375d

Browse files
committed
Add TOSA specification handling to Arm backend
Mandate the need for a TOSA version in the compile spec list passed to the Arm backend and propagate the information to node visitors for serialization handling. Signed-off-by: Per Åstrand <[email protected]> Change-Id: I43b35923f71a312e3064eab9388a4bd2756dc17f
1 parent 2cd27d4 commit ef8375d

File tree

6 files changed

+80
-17
lines changed

6 files changed

+80
-17
lines changed

backends/arm/arm_backend.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from executorch.backends.arm.operators.node_visitor import get_node_visitors
2121
from executorch.backends.arm.operators.op_output import process_output
2222
from executorch.backends.arm.operators.op_placeholder import process_placeholder
23+
24+
from executorch.backends.arm.tosa_specification import TosaSpecification
2325
from executorch.backends.arm._passes.arm_pass_manager import (
2426
ArmPassManager,
2527
) # usort: skip
@@ -87,16 +89,23 @@ def ethosu_compile_spec(
8789
if extra_flags is not None:
8890
self.compiler_flags.append(extra_flags)
8991

92+
base_tosa_version = "TOSA-0.80.0+BI"
93+
if "U55" in config:
94+
# Add the Ethos-U55 extension marker
95+
base_tosa_version += "+u55"
96+
self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)
97+
9098
return self
9199

92-
def tosa_compile_spec(self) -> "ArmCompileSpecBuilder":
100+
def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
93101
"""
94102
Generate compile spec for TOSA flatbuffer output
95103
"""
96104
assert (
97105
self.output_format is None
98106
), f"Output format already set: {self.output_format}"
99107
self.output_format = "tosa"
108+
self.tosa_version = TosaSpecification.create_from_string(tosa_version)
100109
return self
101110

102111
def dump_intermediate_artifacts_to(
@@ -130,6 +139,13 @@ def build(self) -> List[CompileSpec]:
130139
"""
131140
Generate a list of compile spec objects from the builder
132141
"""
142+
assert self.tosa_version
143+
144+
# Always supply a TOSA version
145+
self.compile_spec = [
146+
CompileSpec("tosa_version", str(self.tosa_version).encode())
147+
]
148+
133149
if self.output_format == "vela":
134150
self.compile_spec += [
135151
CompileSpec("output_format", "vela".encode()),
@@ -211,26 +227,33 @@ def preprocess( # noqa: C901
211227
if not output_format:
212228
raise RuntimeError("output format is required")
213229

230+
tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
231+
assert (
232+
tosa_spec is not None
233+
), "TOSA backend needs a TOSA version specified in the CompileSpec!"
234+
214235
if output_format == "vela" and len(compile_flags) == 0:
215236
# Not testing for compile_flags correctness here, just that they are
216237
# present. The compiler will give errors if they are not valid.
217238
raise RuntimeError("compile flags are required for vela output format")
218239

240+
logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")
241+
219242
# Converted output for this subgraph, serializer needs path early as it emits
220243
# const data directly. Path created and data written only in debug builds.
221244
tosa_graph = ts.TosaSerializer(artifact_path)
222245
graph_module = ArmPassManager().transform_to_backend_pipeline(
223246
exported_program=edge_program, compile_spec=compile_spec
224247
)
225248

226-
node_visitors = get_node_visitors(edge_program)
249+
node_visitors = get_node_visitors(edge_program, tosa_spec)
227250

228251
for node in graph_module.graph.nodes:
229252
node = cast(Node, node)
230253
if node.op == "call_function":
231-
process_call_function(node, tosa_graph, node_visitors)
254+
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
232255
elif node.op == "placeholder":
233-
process_placeholder(node, tosa_graph, edge_program)
256+
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
234257
elif node.op == "output":
235258
process_output(node, tosa_graph)
236259
else:

backends/arm/operators/node_visitor.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -10,6 +10,7 @@
1010
import serializer.tosa_serializer as ts
1111
import torch
1212
from executorch.backends.arm.tosa_mapping import TosaArg
13+
from executorch.backends.arm.tosa_specification import TosaSpecification
1314
from torch.export import ExportedProgram
1415

1516

@@ -18,8 +19,19 @@ class NodeVisitor:
1819
Node Visitor pattern for lowering edge IR to TOSA
1920
"""
2021

21-
def __init__(self, exported_program: ExportedProgram):
22+
# Add the currently supported node_visitor specs as default.
23+
# This should be overriden in the NodeVisitor subclasses to target
24+
# a specific TOSA version.
25+
# When all node_visitors has been refactored to target a specific
26+
# version, this list should be removed.
27+
tosa_specs = [
28+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
29+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
30+
]
31+
32+
def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
2233
self._exported_program = exported_program or None
34+
self.tosa_spec = tosa_spec
2335

2436
def define_node(
2537
self,
@@ -33,16 +45,30 @@ def define_node(
3345

3446

3547
# container for all node visitors
36-
_node_visitor_dict = {}
48+
_node_visitor_dicts = {
49+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
50+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
51+
}
3752

3853

3954
def register_node_visitor(visitor):
40-
_node_visitor_dict[visitor.target] = visitor
55+
for tosa_spec in visitor.tosa_specs:
56+
_node_visitor_dicts[tosa_spec][visitor.target] = visitor
57+
return visitor
4158

4259

4360
def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
4461
node_visitors = {}
45-
for target, visitor in _node_visitor_dict.items():
62+
tosa_spec = None
63+
for arg in args:
64+
if isinstance(arg, TosaSpecification):
65+
tosa_spec = arg
66+
break
67+
68+
if tosa_spec is None:
69+
raise RuntimeError("No TOSA specification supplied.")
70+
71+
for target, visitor in _node_visitor_dicts[tosa_spec].items():
4672
node_visitors[target] = visitor(*args)
4773

4874
return node_visitors

backends/arm/operators/op_placeholder.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_quant_node_args,
1515
is_quant_arg,
1616
)
17+
from executorch.backends.arm.tosa_specification import TosaSpecification
1718
from executorch.backends.arm.tosa_utils import (
1819
is_bias_node_for_quantized_addmm,
1920
is_bias_node_for_quantized_conv,
@@ -26,6 +27,7 @@
2627
def process_inputs(
2728
node: torch.fx.Node,
2829
tosa_graph: ts.TosaSerializer,
30+
tosa_spec: TosaSpecification,
2931
):
3032
"""Serialize an input node"""
3133
# inputs need to be in default dim_order (contiguous memory format)
@@ -95,6 +97,7 @@ def process_inputs_to_parameters(
9597
node: torch.fx.Node,
9698
tosa_graph: ts.TosaSerializer,
9799
edge_program: ExportedProgram,
100+
tosa_spec: TosaSpecification,
98101
):
99102
"""Serialize bias and non-quantized weights"""
100103
inputs = [TosaArg(node)]
@@ -106,9 +109,13 @@ def process_inputs_to_parameters(
106109

107110
if is_bias_node_for_quantized_addmm(node) or is_bias_node_for_quantized_conv(node):
108111
# BI bias
112+
assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer"
109113
process_quantized_bias(node, tosa_graph, parameter_values)
110114
else:
111115
# MI weights or bias
116+
if inputs[0].dtype == torch.float32:
117+
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
118+
112119
parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
113120

114121
tosa_graph.addConst(
@@ -158,15 +165,16 @@ def process_placeholder(
158165
node: torch.fx.Node,
159166
tosa_graph: ts.TosaSerializer,
160167
edge_program: ExportedProgram,
168+
tosa_spec: TosaSpecification,
161169
):
162170
"""Wrapper for processing and serializing all types of placeholders"""
163171
assert node.name == node.target, "Expect placeholder name and target to match"
164172
assert 0 == len(node.args), "Can't handle default input values"
165173

166174
if node.name in edge_program.graph_signature.user_inputs:
167-
process_inputs(node, tosa_graph)
175+
process_inputs(node, tosa_graph, tosa_spec)
168176
elif node.name in edge_program.graph_signature.inputs_to_parameters:
169-
process_inputs_to_parameters(node, tosa_graph, edge_program)
177+
process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
170178
elif node.name in edge_program.graph_signature.inputs_to_buffers:
171179
process_inputs_to_buffers(node, tosa_graph, edge_program)
172180
elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:

backends/arm/test/common.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,18 @@ def maybe_get_tosa_collate_path() -> str | None:
177177

178178

179179
def get_tosa_compile_spec(
180-
permute_memory_to_nhwc=True, custom_path=None
180+
tosa_version: str, permute_memory_to_nhwc=True, custom_path=None
181181
) -> list[CompileSpec]:
182182
"""
183183
Default compile spec for TOSA tests.
184184
"""
185-
return get_tosa_compile_spec_unbuilt(permute_memory_to_nhwc, custom_path).build()
185+
return get_tosa_compile_spec_unbuilt(
186+
tosa_version, permute_memory_to_nhwc, custom_path
187+
).build()
186188

187189

188190
def get_tosa_compile_spec_unbuilt(
189-
permute_memory_to_nhwc=False, custom_path=None
191+
tosa_version: str, permute_memory_to_nhwc=False, custom_path=None
190192
) -> ArmCompileSpecBuilder:
191193
"""Get the ArmCompileSpecBuilder for the default TOSA tests, to modify
192194
the compile spec before calling .build() to finalize it.
@@ -198,7 +200,7 @@ def get_tosa_compile_spec_unbuilt(
198200
os.makedirs(custom_path, exist_ok=True)
199201
compile_spec_builder = (
200202
ArmCompileSpecBuilder()
201-
.tosa_compile_spec()
203+
.tosa_compile_spec(tosa_version)
202204
.set_permute_memory_format(permute_memory_to_nhwc)
203205
.dump_intermediate_artifacts_to(custom_path)
204206
)

backends/arm/tosa_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
is_quant_node,
2222
q_op,
2323
)
24+
from executorch.backends.arm.tosa_specification import TosaSpecification
2425
from executorch.exir.dialects._ops import ops as exir_ops
2526
from serializer.tosa_serializer import TosaOp
2627
from torch.fx import Node
@@ -290,6 +291,7 @@ def process_call_function(
290291
node: torch.fx.Node,
291292
tosa_graph: ts.TosaSerializer,
292293
node_visitors: Dict[str, NodeVisitor],
294+
tosa_spec: TosaSpecification,
293295
):
294296
# Unpack arguments and convert
295297
inputs = getNodeArgs(node)
@@ -319,7 +321,7 @@ def process_call_function(
319321
is_quant_node(node),
320322
)
321323
else:
322-
raise RuntimeError(f"Unknown operator {node.target}")
324+
raise RuntimeError(f"Unknown operator {node.target} for TOSA : {tosa_spec}")
323325

324326

325327
def expand_dims(

examples/arm/aot_arm_compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,9 @@ def get_compile_spec(
180180
spec_builder = None
181181
if target == "TOSA":
182182
spec_builder = (
183-
ArmCompileSpecBuilder().tosa_compile_spec().set_permute_memory_format(True)
183+
ArmCompileSpecBuilder()
184+
.tosa_compile_spec("TOSA-0.80.0+BI")
185+
.set_permute_memory_format(True)
184186
)
185187
elif "ethos-u55" in target:
186188
spec_builder = (

0 commit comments

Comments
 (0)