Skip to content

Commit 0081bef

Browse files
Arm backend: Add complie spec factories (pytorch#14376)
Signed-off-by: Erik Lundell <[email protected]> Co-authored-by: Digant Desai <[email protected]>
1 parent 871fe39 commit 0081bef

File tree

7 files changed

+121
-96
lines changed

7 files changed

+121
-96
lines changed

backends/arm/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,17 @@ runtime.python_library(
106106
"//caffe2:torch",
107107
]
108108
)
109+
runtime.python_library(
110+
name = "_factory",
111+
srcs = [
112+
"util/_factory.py"
113+
],
114+
deps = [
115+
":ethosu",
116+
":vgf",
117+
":arm_compile_spec",
118+
"//executorch/backends/arm/quantizer:lib",
119+
"//executorch/exir/backend:operator_support",
120+
"//executorch/exir/backend:compile_spec_schema",
121+
]
122+
)

backends/arm/test/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
16
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
27
load(":targets.bzl", "define_arm_tests")
38

@@ -58,6 +63,7 @@ runtime.python_library(
5863
"//executorch/backends/arm/quantizer:lib",
5964
"//executorch/backends/arm/tosa:mapping",
6065
"//executorch/backends/arm:vgf",
66+
"//executorch/backends/arm:_factory",
6167
"//executorch/devtools/backend_debug:delegation_info",
6268
"//executorch/exir/backend:operator_support",
6369
"fbsource//third-party/pypi/tabulate:tabulate",

backends/arm/test/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import pytest
1616
from executorch.backends.arm.ethosu import EthosUCompileSpec
17+
1718
from executorch.backends.arm.test.runner_utils import (
1819
arm_executor_runner_exists,
1920
corstone300_installed,

backends/arm/test/tester/arm_tester.py

Lines changed: 30 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,11 @@
2828

2929
import torch.fx
3030
import torch.utils._pytree as pytree
31-
3231
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager
3332

3433
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
35-
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
36-
from executorch.backends.arm.quantizer import (
37-
EthosUQuantizer,
38-
get_symmetric_quantization_config,
39-
TOSAQuantizer,
40-
VgfQuantizer,
41-
)
34+
from executorch.backends.arm.ethosu import EthosUCompileSpec
35+
from executorch.backends.arm.quantizer import get_symmetric_quantization_config
4236
from executorch.backends.arm.test.runner_utils import (
4337
dbg_tosa_fb_to_json,
4438
get_output_quantization_params,
@@ -53,9 +47,13 @@
5347
from executorch.backends.arm.tosa import TosaSpecification
5448
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
5549
from executorch.backends.arm.tosa.mapping import extract_tensor_meta
56-
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
5750

58-
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
51+
from executorch.backends.arm.util._factory import (
52+
create_partitioner,
53+
create_quantizer,
54+
parse_compile_spec,
55+
)
56+
from executorch.backends.arm.vgf import VgfCompileSpec
5957

6058
from executorch.backends.test.harness.error_statistics import ErrorStatistics
6159
from executorch.backends.test.harness.stages import Stage, StageType
@@ -83,7 +81,6 @@
8381
_copy_module,
8482
_update_exported_program_graph_module,
8583
)
86-
8784
from tabulate import tabulate
8885

8986
from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
@@ -103,26 +100,20 @@ def _dump_lowered_modules_artifact(
103100
artifact.exported_program().graph_signature
104101
)
105102

106-
def get_output_format(lowered_module) -> str | None:
107-
for spec in lowered_module.compile_specs:
108-
if spec.key == "output_format":
109-
return spec.value.decode()
110-
return None
111-
112103
for node in graph_module.graph.nodes:
113104
if node.op == "get_attr" and node.name.startswith("lowered_module_"):
114105
lowered_module = getattr(graph_module, node.name)
115106
assert isinstance(
116107
lowered_module, LoweredBackendModule
117108
), f"Attribute {node.name} must be of type LoweredBackendModule."
118109

119-
output_format = get_output_format(lowered_module)
120-
if output_format == "tosa":
110+
compile_spec = parse_compile_spec(lowered_module.compile_specs)
111+
if isinstance(compile_spec, TosaCompileSpec):
121112
tosa_fb = lowered_module.processed_bytes
122113
to_print = dbg_tosa_fb_to_json(tosa_fb)
123114
to_print = pformat(to_print, compact=True, indent=1)
124115
output += f"\nTOSA deserialized {node.name}: \n{to_print}\n"
125-
elif output_format == EthosUCompileSpec.get_output_format():
116+
elif isinstance(compile_spec, EthosUCompileSpec):
126117
vela_cmd_stream = lowered_module.processed_bytes
127118
output += f"\nVela command stream {node.name}: \n{vela_cmd_stream}\n"
128119
else:
@@ -284,13 +275,7 @@ def quantize(
284275
quantize_stage: Optional[tester.Quantize] = None,
285276
):
286277
if quantize_stage is None:
287-
quantizer = None
288-
if isinstance(self.compile_spec, TosaCompileSpec):
289-
quantizer = TOSAQuantizer(self.compile_spec)
290-
elif isinstance(self.compile_spec, EthosUCompileSpec):
291-
quantizer = EthosUQuantizer(self.compile_spec)
292-
elif isinstance(self.compile_spec, VgfCompileSpec):
293-
quantizer = VgfQuantizer(self.compile_spec)
278+
quantizer = create_quantizer(self.compile_spec)
294279
quantize_stage = tester.Quantize(
295280
quantizer,
296281
get_symmetric_quantization_config(),
@@ -312,14 +297,7 @@ def to_edge(
312297

313298
def partition(self, partition_stage: Optional[Partition] = None):
314299
if partition_stage is None:
315-
if isinstance(self.compile_spec, TosaCompileSpec):
316-
arm_partitioner = TOSAPartitioner(self.compile_spec)
317-
elif isinstance(self.compile_spec, EthosUCompileSpec):
318-
arm_partitioner = EthosUPartitioner(self.compile_spec)
319-
elif isinstance(self.compile_spec, VgfCompileSpec):
320-
arm_partitioner = VgfPartitioner(self.compile_spec)
321-
else:
322-
raise ValueError("compile spec doesn't target any Arm Partitioner")
300+
arm_partitioner = create_partitioner(self.compile_spec)
323301
partition_stage = Partition(arm_partitioner)
324302
return super().partition(partition_stage)
325303

@@ -329,7 +307,7 @@ def to_edge_transform_and_lower(
329307
partitioners: Optional[List[Partitioner]] = None,
330308
edge_compile_config: Optional[EdgeCompileConfig] = None,
331309
additional_checks: Optional[
332-
List[Union[DontPartition | DontPartitionModule | DontPartitionName]]
310+
List[DontPartition | DontPartitionModule | DontPartitionName]
333311
] = None,
334312
transform_passes: Optional[
335313
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
@@ -343,20 +321,9 @@ def to_edge_transform_and_lower(
343321

344322
if to_edge_and_lower_stage is None:
345323
if partitioners is None:
346-
if isinstance(self.compile_spec, TosaCompileSpec):
347-
arm_partitioner = TOSAPartitioner(
348-
self.compile_spec, additional_checks
349-
)
350-
elif isinstance(self.compile_spec, EthosUCompileSpec):
351-
arm_partitioner = EthosUPartitioner(
352-
self.compile_spec, additional_checks
353-
)
354-
elif isinstance(self.compile_spec, VgfCompileSpec):
355-
arm_partitioner = VgfPartitioner(
356-
self.compile_spec, additional_checks
357-
)
358-
else:
359-
raise ValueError("compile spec doesn't target any Arm Partitioner")
324+
arm_partitioner = create_partitioner(
325+
self.compile_spec, additional_checks
326+
)
360327
partitioners = [arm_partitioner]
361328
to_edge_and_lower_stage = ToEdgeTransformAndLower(
362329
partitioners,
@@ -743,22 +710,19 @@ def _get_tosa_operator_distribution(
743710
op_list = []
744711
id = 0
745712
while lowered_module := getattr(graph_module, f"lowered_module_{id}", None):
746-
for spec in lowered_module.compile_specs:
747-
if spec.key != "output_format":
748-
continue
749-
if spec.value == b"tosa":
750-
tosa_fb = lowered_module.processed_bytes
751-
tosa_json = dbg_tosa_fb_to_json(tosa_fb)
752-
for region in tosa_json["regions"]:
753-
for block in region["blocks"]:
754-
op_list.extend(
755-
[operator["op"] for operator in block["operators"]]
756-
)
757-
break
758-
elif spec.value == EthosUCompileSpec.get_output_format().encode():
759-
return "Can not get operator distribution for Vela command stream."
760-
else:
761-
return f"Unknown output format '{spec.value}'."
713+
compile_spec = parse_compile_spec(lowered_module.compile_specs)
714+
if isinstance(compile_spec, TosaCompileSpec):
715+
tosa_fb = lowered_module.processed_bytes
716+
tosa_json = dbg_tosa_fb_to_json(tosa_fb)
717+
for region in tosa_json["regions"]:
718+
for block in region["blocks"]:
719+
op_list.extend([operator["op"] for operator in block["operators"]])
720+
elif isinstance(compile_spec, EthosUCompileSpec):
721+
return "Can not get operator distribution for Vela command stream."
722+
elif isinstance(compile_spec, VgfCompileSpec):
723+
return "Can not get operator distribution for VGF."
724+
else:
725+
return f"Unknown output format '{compile_spec.get_output_format()}'."
762726
id += 1
763727
if id == 0:
764728
return "No delegate with name 'lowered_module_0 found in graph module."

backends/arm/tosa/backend.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def filter_tosa_compile_specs(
206206
hardware.
207207
"""
208208

209-
new_compile_spec = TosaCompileSpec.__new__(TosaCompileSpec)
210-
new_compile_spec._set_compile_specs(
211-
compile_spec.tosa_spec, [], compile_spec.get_intermediate_path()
209+
return (
210+
TosaCompileSpec(compile_spec.tosa_spec)
211+
.dump_intermediate_artifacts_to(compile_spec.get_intermediate_path())
212+
.dump_debug_info(compile_spec.tosa_debug_mode)
212213
)
213-
return new_compile_spec

backends/arm/util/_factory.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
7+
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
8+
from executorch.backends.arm.quantizer import (
9+
EthosUQuantizer,
10+
TOSAQuantizer,
11+
VgfQuantizer,
12+
)
13+
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
14+
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
15+
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
16+
from executorch.exir.backend.compile_spec_schema import CompileSpec
17+
from torch.fx.passes.operator_support import OperatorSupportBase
18+
19+
20+
def parse_compile_spec(compile_specs: list[CompileSpec]) -> ArmCompileSpec:
21+
output_format = None
22+
for spec in compile_specs:
23+
if spec.key == "output_format":
24+
output_format = spec.value.decode()
25+
break
26+
else:
27+
raise ValueError("Compile spec without output format.")
28+
if output_format == TosaCompileSpec.get_output_format():
29+
return TosaCompileSpec.from_list(compile_specs)
30+
if output_format == EthosUCompileSpec.get_output_format():
31+
return EthosUCompileSpec.from_list(compile_specs)
32+
if output_format == VgfCompileSpec.get_output_format():
33+
return VgfCompileSpec.from_list(compile_specs)
34+
raise ValueError(f"Unknown output format {output_format}")
35+
36+
37+
def create_partitioner(
38+
compile_spec: ArmCompileSpec,
39+
additional_checks: list[OperatorSupportBase] | None = None,
40+
):
41+
if isinstance(compile_spec, TosaCompileSpec):
42+
return TOSAPartitioner(compile_spec, additional_checks)
43+
elif isinstance(compile_spec, EthosUCompileSpec):
44+
return EthosUPartitioner(compile_spec, additional_checks)
45+
elif isinstance(compile_spec, VgfCompileSpec):
46+
return VgfPartitioner(compile_spec, additional_checks)
47+
else:
48+
raise ValueError("compile spec doesn't target any Arm Partitioner")
49+
50+
51+
def create_quantizer(compile_spec: ArmCompileSpec):
52+
if isinstance(compile_spec, TosaCompileSpec):
53+
return TOSAQuantizer(compile_spec)
54+
elif isinstance(compile_spec, EthosUCompileSpec):
55+
return EthosUQuantizer(compile_spec)
56+
elif isinstance(compile_spec, VgfCompileSpec):
57+
return VgfQuantizer(compile_spec)
58+
else:
59+
raise ValueError("compile spec doesn't target any Arm Quantizer")

examples/arm/aot_arm_compiler.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,18 @@
1818
import torch
1919
from examples.devtools.scripts.export_bundled_program import save_bundled_program
2020
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
21-
from executorch.backends.arm.ethosu import EthosUCompileSpec, EthosUPartitioner
22-
from executorch.backends.arm.quantizer import (
23-
EthosUQuantizer,
24-
get_symmetric_quantization_config,
25-
TOSAQuantizer,
26-
VgfQuantizer,
27-
)
21+
from executorch.backends.arm.ethosu import EthosUCompileSpec
22+
from executorch.backends.arm.quantizer import get_symmetric_quantization_config
2823
from executorch.backends.arm.tosa import TosaSpecification
2924
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
30-
from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
25+
from executorch.backends.arm.util._factory import create_partitioner, create_quantizer
3126

3227
from executorch.backends.arm.util.arm_model_evaluator import (
3328
evaluate_model,
3429
evaluator_calibration_data,
3530
)
3631

37-
from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
32+
from executorch.backends.arm.vgf import VgfCompileSpec
3833

3934
# To use Cortex-M backend
4035
from executorch.backends.cortex_m.passes.quantized_linear_fusion_pass import (
@@ -158,15 +153,8 @@ def quantize(
158153
export"""
159154
logging.info("Quantizing Model...")
160155
logging.debug(f"Original model: {model}")
161-
quantizer = None
162-
if isinstance(compile_specs, EthosUCompileSpec):
163-
quantizer = EthosUQuantizer(compile_specs)
164-
elif isinstance(compile_specs, TosaCompileSpec):
165-
quantizer = TOSAQuantizer(compile_specs)
166-
elif isinstance(compile_specs, VgfCompileSpec):
167-
quantizer = VgfQuantizer(compile_specs)
168-
else:
169-
raise RuntimeError("Unsupported compilespecs for quantization!")
156+
157+
quantizer = create_quantizer(compile_specs)
170158

171159
operator_config = get_symmetric_quantization_config()
172160
quantizer.set_global(operator_config)
@@ -649,14 +637,7 @@ def to_edge_TOSA_delegate(
649637
args, model, example_inputs, compile_spec
650638
)
651639

652-
if isinstance(compile_spec, EthosUCompileSpec):
653-
partitioner = EthosUPartitioner(compile_spec)
654-
elif isinstance(compile_spec, TosaCompileSpec):
655-
partitioner = TOSAPartitioner(compile_spec)
656-
elif isinstance(compile_spec, VgfCompileSpec):
657-
partitioner = VgfPartitioner(compile_spec)
658-
else:
659-
raise RuntimeError(f"Unhandled compile spec: {compile_spec}")
640+
partitioner = create_partitioner(compile_spec)
660641

661642
edge = to_edge_transform_and_lower(
662643
exported_program,

0 commit comments

Comments
 (0)