Skip to content

Commit 0f066e0

Browse files
authored
Arm backend: use ArmCompileSpec in backend
Differential Revision: D82109615 Pull Request resolved: #14140
1 parent 0447ebd commit 0f066e0

File tree

7 files changed

+43
-107
lines changed

7 files changed

+43
-107
lines changed

backends/arm/ethosu/backend.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from typing import final, List
1616

1717
from executorch.backends.arm.arm_vela import vela_compile
18+
from executorch.backends.arm.ethosu.compile_spec import EthosUCompileSpec
1819

1920
from executorch.backends.arm.tosa.backend import TOSABackend
2021
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
@@ -35,16 +36,13 @@ class EthosUBackend(BackendDetails):
3536

3637
@staticmethod
3738
def _compile_tosa_flatbuffer(
38-
tosa_flatbuffer: bytes, compile_spec: List[CompileSpec]
39+
tosa_flatbuffer: bytes, compile_spec: EthosUCompileSpec
3940
) -> bytes:
4041
"""
4142
Static helper method to do the compilation of the TOSA flatbuffer
4243
representation to a target specific binary stream.
4344
"""
44-
compile_flags = []
45-
for spec in compile_spec:
46-
if spec.key == "compile_flags":
47-
compile_flags.append(spec.value.decode())
45+
compile_flags = compile_spec.compiler_flags
4846

4947
if len(compile_flags) == 0:
5048
# Not testing for compile_flags correctness here, just that they are
@@ -64,10 +62,11 @@ def _compile_tosa_flatbuffer(
6462
@staticmethod
6563
def preprocess(
6664
edge_program: ExportedProgram,
67-
compile_spec: List[CompileSpec],
65+
compile_specs: List[CompileSpec],
6866
) -> PreprocessResult:
6967
logger.info(f"{EthosUBackend.__name__} preprocess")
7068

69+
compile_spec = EthosUCompileSpec.from_list(compile_specs)
7170
# deduce TOSA compile_spec from Ethos-U compile spec. We get a new
7271
# compile spec list, containing only elements relevant for the
7372
# TOSABackend.
@@ -77,7 +76,7 @@ def preprocess(
7776
# ('All backend implementation are final...'), so use composition instead.
7877
# preprocess returns the serialized TOSA flatbuffer in .processed_bytes,
7978
# which can be passed on to next compilation step.
80-
tosa_preprocess = TOSABackend.preprocess(edge_program, tosa_compile_spec)
79+
tosa_preprocess = TOSABackend._preprocess(edge_program, tosa_compile_spec)
8180

8281
binary = EthosUBackend._compile_tosa_flatbuffer(
8382
tosa_preprocess.processed_bytes, compile_spec

backends/arm/test/misc/test_tosa_spec.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,8 @@
55

66
import unittest
77

8-
from executorch.backends.arm.tosa.specification import (
9-
get_tosa_spec,
10-
Tosa_1_00,
11-
TosaSpecification,
12-
)
8+
from executorch.backends.arm.tosa.specification import Tosa_1_00, TosaSpecification
139

14-
from executorch.exir.backend.compile_spec_schema import CompileSpec
1510
from parameterized import parameterized # type: ignore[import-untyped]
1611

1712
test_valid_strings = [
@@ -43,14 +38,6 @@
4338
"TOSA-1.0.0+BF16+fft+int4+cf+INT",
4439
]
4540

46-
test_compile_specs = [
47-
([CompileSpec("tosa_spec", "TOSA-1.0.0+INT".encode())],),
48-
]
49-
50-
test_compile_specs_no_version = [
51-
([CompileSpec("other_key", "some_value".encode())],),
52-
]
53-
5441

5542
class TestTosaSpecification(unittest.TestCase):
5643
"""Tests the TOSA specification class"""
@@ -74,19 +61,6 @@ def test_invalid_version_strings(self, version_string: str):
7461

7562
assert tosa_spec is None
7663

77-
@parameterized.expand(test_compile_specs) # type: ignore[misc]
78-
def test_create_from_compilespec(self, compile_specs: list[CompileSpec]):
79-
tosa_spec = get_tosa_spec(compile_specs)
80-
assert isinstance(tosa_spec, TosaSpecification)
81-
82-
@parameterized.expand(test_compile_specs_no_version) # type: ignore[misc]
83-
def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec]):
84-
tosa_spec = None
85-
with self.assertRaises(ValueError):
86-
tosa_spec = get_tosa_spec(compile_specs)
87-
88-
assert tosa_spec is None
89-
9064
@parameterized.expand(test_valid_strings)
9165
def test_correct_string_representation(self, version_string: str):
9266
tosa_spec = TosaSpecification.create_from_string(version_string)

backends/arm/test/tester/arm_tester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def __init__(
303303
Args:
304304
model (torch.nn.Module): The model to test
305305
example_inputs (Tuple[torch.Tensor]): Example inputs to the model
306-
compile_spec (List[CompileSpec]): The compile spec to use
306+
compile_spec (ArmCompileSpec): The compile spec to use
307307
"""
308308

309309
self.transform_passes = transform_passes

backends/arm/tosa/backend.py

Lines changed: 20 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
process_output,
2525
process_placeholder,
2626
)
27-
from executorch.backends.arm.tosa.specification import get_tosa_spec
27+
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
2828
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2929
from executorch.exir.backend.compile_spec_schema import CompileSpec
3030
from torch.export.exported_program import ExportedProgram
@@ -80,38 +80,24 @@ class TOSABackend(BackendDetails):
8080
"""
8181

8282
@staticmethod
83-
def preprocess( # noqa: C901
83+
def preprocess(edge_program: ExportedProgram, compile_specs: List[CompileSpec]):
84+
return TOSABackend._preprocess(
85+
edge_program, TosaCompileSpec.from_list(compile_specs)
86+
)
87+
88+
@staticmethod
89+
def _preprocess( # noqa: C901
8490
edge_program: ExportedProgram,
85-
compile_spec: List[CompileSpec],
91+
compile_spec: TosaCompileSpec,
8692
) -> PreprocessResult:
8793
# if a debug/test build capture output files from TOSA stage
88-
artifact_path = None
89-
output_format = ""
90-
compile_flags = []
91-
dump_debug_info = None
92-
for spec in compile_spec:
93-
if spec.key == "debug_artifact_path":
94-
artifact_path = spec.value.decode()
95-
if spec.key == "output_format":
96-
output_format = spec.value.decode()
97-
if spec.key == "compile_flags":
98-
compile_flags.append(spec.value.decode())
99-
if spec.key == "dump_debug_info":
100-
dump_debug_info = spec.value.decode()
101-
102-
# Check that the output format is set correctly in the compile spec
103-
if output_format != "tosa":
104-
raise ValueError(f'Invalid output format {output_format}, must be "tosa"')
94+
artifact_path = compile_spec.get_intermediate_path()
95+
tosa_spec = compile_spec.tosa_spec
96+
dump_debug_info = compile_spec.tosa_debug_mode
10597

10698
# Assign to every node external id
10799
node_2_id = _annotate_external_ids(edge_program.graph)
108100

109-
tosa_spec = get_tosa_spec(compile_spec)
110-
if tosa_spec is None:
111-
raise ValueError(
112-
"TOSA backend needs a TOSA version specified in the CompileSpec"
113-
)
114-
115101
logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")
116102

117103
# Converted output for this subgraph, serializer needs path early as it emits
@@ -132,7 +118,7 @@ def preprocess( # noqa: C901
132118

133119
debug_hook = None
134120
if dump_debug_info is not None:
135-
debug_hook = DebugHook(ArmCompileSpec.DebugMode[dump_debug_info])
121+
debug_hook = DebugHook(dump_debug_info)
136122

137123
# TODO: Fix the need to lazily import this.
138124
from executorch.backends.arm.operators.node_visitor import get_node_visitors
@@ -204,8 +190,8 @@ def _sort_key(t: Node) -> int:
204190

205191
@staticmethod
206192
def filter_tosa_compile_specs(
207-
compile_spec: List[CompileSpec],
208-
) -> List[CompileSpec]:
193+
compile_spec: ArmCompileSpec,
194+
) -> TosaCompileSpec:
209195
"""
210196
Filter out the CompileSpec elements relevant for the TOSA backend.
211197
This is needed to compose a backend targetting hardware IP with the
@@ -214,17 +200,9 @@ def filter_tosa_compile_specs(
214200
flatbuffer can then be consumed by the backend targetting specific
215201
hardware.
216202
"""
217-
tosa_compile_spec = []
218-
tosa_compile_spec.append(CompileSpec("output_format", "tosa".encode()))
219-
220-
# Copy everything that's TOSA generic
221-
tosa_backend_compile_spec_keys = [
222-
"tosa_spec",
223-
"debug_artifact_path",
224-
]
225203

226-
for spec in compile_spec:
227-
if spec.key in tosa_backend_compile_spec_keys:
228-
tosa_compile_spec.append(CompileSpec(spec.key, spec.value))
229-
230-
return tosa_compile_spec
204+
new_compile_spec = TosaCompileSpec.__new__(TosaCompileSpec)
205+
new_compile_spec._set_compile_specs(
206+
compile_spec.tosa_spec, [], compile_spec.get_intermediate_path()
207+
)
208+
return new_compile_spec

backends/arm/tosa/partitioner.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def __init__(
6565
self.delegation_spec = DelegationSpec(
6666
TOSABackend.__name__, compile_spec.to_list()
6767
)
68+
self.tosa_spec = compile_spec.tosa_spec
6869
self.additional_checks = additional_checks
6970
self.tosa_spec = compile_spec.tosa_spec
7071

@@ -75,13 +76,13 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # no
7576
logger.info("TOSAPartitioner::partition")
7677
partition_tags: dict[str, DelegationSpec] = {}
7778

78-
tosa_spec = self.tosa_spec
79-
80-
logger.info(f"Partitioning for {self.delegation_spec.backend_id}: {tosa_spec}")
79+
logger.info(
80+
f"Partitioning for {self.delegation_spec.backend_id}: {self.tosa_spec}"
81+
)
8182

8283
reporter = WhyNoPartitionReporter()
8384
operator_support = tosa_support_factory(
84-
tosa_spec, exported_program, reporter, self.additional_checks
85+
self.tosa_spec, exported_program, reporter, self.additional_checks
8586
)
8687
capability_partitioner = CapabilityBasedPartitioner(
8788
exported_program.graph_module,
@@ -131,7 +132,7 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
131132
break
132133
continue
133134

134-
if tosa_spec.support_float():
135+
if self.tosa_spec.support_float():
135136
continue
136137

137138
if is_partitioned(node):
@@ -163,7 +164,7 @@ def is_partitioned(node: torch.fx.Node, tag=tag) -> bool:
163164
)
164165

165166
tag_constant_data(exported_program)
166-
logger.info(f"The following nodes were rejected for {tosa_spec}:")
167+
logger.info(f"The following nodes were rejected for {self.tosa_spec}:")
167168
logger.info("\n" + reporter.get_table_report())
168169
logger.info("(Placeholders and outputs are not included in this list)")
169170
return PartitionResult(
@@ -213,8 +214,7 @@ def filter_fn(node: torch.fx.Node) -> bool:
213214
torch.ops.aten.logit.default,
214215
] + ops_to_not_decompose_if_quant_op
215216

216-
tosa_spec = self.tosa_spec
217-
if not tosa_spec.is_U55_subset:
217+
if not self.tosa_spec.is_U55_subset:
218218
# Tosa operator "RESIZE" is not supported on U55. Since upsample_bilinear2d
219219
# and upsample_nearest2d decompose into that it will not be possible to
220220
# delegate those operators on U55. If we have said here to not decompose

backends/arm/tosa/specification.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,6 @@
1515
import re
1616
from typing import List
1717

18-
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-not-found]
19-
CompileSpec,
20-
)
21-
2218
from packaging.version import Version
2319

2420

@@ -199,10 +195,3 @@ def get_context_spec() -> TosaSpecification:
199195
return TosaLoweringContext.tosa_spec_var.get()
200196
except LookupError:
201197
raise RuntimeError("Function must be executed within a TosaLoweringContext")
202-
203-
204-
def get_tosa_spec(compile_spec: List[CompileSpec]) -> TosaSpecification:
205-
for spec in compile_spec:
206-
if spec.key == "tosa_spec":
207-
return TosaSpecification.create_from_string(spec.value.decode())
208-
raise ValueError("Could not find TOSA version in CompileSpec")

backends/arm/vgf/backend.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
arm_get_first_delegation_tag,
2323
TOSABackend,
2424
)
25+
from executorch.backends.arm.vgf.compile_spec import VgfCompileSpec
2526
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
2627
from executorch.exir.backend.compile_spec_schema import CompileSpec
2728
from torch.export.exported_program import ExportedProgram
@@ -40,32 +41,27 @@ class VgfBackend(BackendDetails):
4041
@staticmethod
4142
def _compile_tosa_flatbuffer(
4243
tosa_flatbuffer: bytes,
43-
compile_spec: List[CompileSpec],
44+
compile_spec: VgfCompileSpec,
4445
tag_name: str = "",
4546
) -> bytes:
4647
"""
4748
Static helper method to do the compilation of the TOSA flatbuffer
4849
representation to a target specific binary stream.
4950
"""
50-
compile_flags = []
51-
artifact_path = None
52-
for spec in compile_spec:
53-
if spec.key == "compile_flags":
54-
compile_flags.append(spec.value.decode())
55-
if spec.key == "debug_artifact_path":
56-
artifact_path = spec.value.decode()
57-
51+
compile_flags = compile_spec.compiler_flags
52+
artifact_path = compile_spec.get_intermediate_path()
5853
# Pass on the TOSA flatbuffer to the vgf compiler.
5954
binary = vgf_compile(tosa_flatbuffer, compile_flags, artifact_path, tag_name)
6055
return binary
6156

6257
@staticmethod
6358
def preprocess(
6459
edge_program: ExportedProgram,
65-
compile_spec: List[CompileSpec],
60+
compile_specs: List[CompileSpec],
6661
) -> PreprocessResult:
6762
logger.info(f"{VgfBackend.__name__} preprocess")
6863

64+
compile_spec = VgfCompileSpec.from_list(compile_specs)
6965
# deduce TOSA compile_spec from VGF compile spec. We get a new
7066
# compile spec list, containing only elements relevant for the
7167
# TOSABackend.
@@ -75,7 +71,7 @@ def preprocess(
7571
# ('All backend implementation are final...'), so use composition instead.
7672
# preprocess returns the serialized TOSA flatbuffer in .processed_bytes,
7773
# which can be passed on to next compilation step.
78-
tosa_preprocess = TOSABackend.preprocess(edge_program, tosa_compile_spec)
74+
tosa_preprocess = TOSABackend._preprocess(edge_program, tosa_compile_spec)
7975

8076
tag_name = arm_get_first_delegation_tag(edge_program.graph_module)
8177

0 commit comments

Comments
 (0)