Skip to content

Commit 5b09b23

Browse files
authored
Merge branch 'main' into fix-typo-llm-export-tutorial
2 parents 8731dfb + b4d72f1 commit 5b09b23

37 files changed

+1580
-236
lines changed

backends/arm/_passes/decompose_int16_activation_conv2d_pass.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from executorch.backends.arm._passes.arm_pass import ArmPass
1111
from executorch.backends.arm._passes.quant_args import QuantArgs
1212

13-
from executorch.backends.arm.tosa.specification import get_context_spec, Tosa_1_00
13+
from executorch.backends.arm.tosa.specification import get_context_spec
1414
from executorch.exir.dialects._ops import ops as exir_ops
1515
from executorch.exir.pass_base import ExportPass
1616

@@ -40,9 +40,7 @@ def call_operator(self, op, args, kwargs, meta):
4040
if args[0].data.dtype == torch.int8:
4141
return super().call_operator(op, args, kwargs, meta)
4242
elif args[0].data.dtype == torch.int16:
43-
if isinstance(tosa_spec, Tosa_1_00) and not tosa_spec.support_extension(
44-
"int16"
45-
):
43+
if not tosa_spec.support_extension("int16"):
4644
raise ValueError(
4745
"int16 activation for convolution requires TOSA int16 extension"
4846
)

backends/arm/common/arm_compile_spec.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,22 @@ class DebugMode(Enum):
3535
_OUTPUT_FORMAT_KEY = "output_format"
3636
_DEBUG_ARTIFACT_KEY = "debug_artifact_path"
3737
_DEBUG_MODE_KEY = "dump_debug_info"
38+
_OUTPUT_REORDER_KEY = "ouput_reorder_workaround"
3839

3940
def _set_compile_specs(
4041
self,
4142
tosa_spec: TosaSpecification,
4243
compiler_flags: list[str],
4344
path_for_intermediates: str | None = None,
4445
tosa_debug_mode: DebugMode | None = None,
46+
output_order_workaround: bool = True,
4547
):
4648
"""Set all values of dataclass directly."""
4749
self.tosa_spec = tosa_spec
4850
self.compiler_flags = compiler_flags
4951
self.path_for_intermediates = path_for_intermediates
5052
self.tosa_debug_mode = tosa_debug_mode
53+
self.output_order_workaround = output_order_workaround
5154

5255
@classmethod
5356
def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
@@ -56,10 +59,15 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
5659
compiler_flags: list[str] | None = None
5760
path_for_intermediates: str | None = None
5861
tosa_debug_mode: ArmCompileSpec.DebugMode | None = None
62+
output_order_workaround: bool = True
5963
unknown_specs: dict[str, str] = {}
6064
for spec in compile_specs:
6165
key = spec.key
62-
val = spec.value.decode()
66+
val = (
67+
spec.value.decode()
68+
if isinstance(spec.value, (bytes, bytearray))
69+
else spec.value
70+
)
6371
if key == ArmCompileSpec._TOSA_SPEC_KEY:
6472
if tosa_spec is not None:
6573
raise ValueError("More than one tosa_spec entry in compile spec.")
@@ -88,6 +96,8 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
8896
"More than one tosa_debug_mode entry in compile spec."
8997
)
9098
tosa_debug_mode = ArmCompileSpec.DebugMode[val]
99+
elif key == ArmCompileSpec._OUTPUT_REORDER_KEY:
100+
output_order_workaround = val # type: ignore[assignment]
91101
else:
92102
unknown_specs[key] = val
93103

@@ -109,6 +119,7 @@ def from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901
109119
compiler_flags=compiler_flags,
110120
path_for_intermediates=path_for_intermediates,
111121
tosa_debug_mode=tosa_debug_mode,
122+
output_order_workaround=output_order_workaround,
112123
)
113124
cls.from_list_hook(compile_spec, unknown_specs)
114125
compile_spec.validate()
@@ -170,6 +181,14 @@ def to_list(self):
170181
)
171182
)
172183

184+
if not self.output_order_workaround:
185+
compile_spec.append(
186+
CompileSpec(
187+
ArmCompileSpec._OUTPUT_REORDER_KEY,
188+
self.output_order_workaround,
189+
)
190+
)
191+
173192
return compile_spec
174193

175194
def get_intermediate_path(self) -> str | None:
@@ -201,6 +220,13 @@ def dump_debug_info(self, debug_mode: DebugMode | None):
201220
self.tosa_debug_mode = debug_mode
202221
return self
203222

223+
def set_output_order_workaround(self, output_order_workaround: bool):
224+
self.output_order_workaround = output_order_workaround
225+
return self
226+
227+
def get_output_order_workaround(self) -> bool:
228+
return self.output_order_workaround
229+
204230
@classmethod
205231
@abstractmethod
206232
def get_output_format(cls) -> str:

backends/arm/ethosu/backend.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# backends. Converts via TOSA as an intermediate form supported by AoT and
1010
# JIT compiler flows.
1111
#
12+
"""Ahead-of-time Arm Ethos-U backend built on the shared TOSA pipeline."""
1213

1314
import logging
1415
from typing import final, List
@@ -27,19 +28,28 @@
2728

2829
@final
2930
class EthosUBackend(BackendDetails):
30-
"""
31-
BackendDetails subclass for delegation to Ethos-U. Deduce the TOSA lowering from
32-
the compile spec list by filtering out the compile spec values that are of interest
33-
for the TOSABackend.
31+
"""BackendDetails subclass for delegation to Ethos-U.
32+
33+
Deduce the TOSA lowering from the compile spec list by filtering out the
34+
compile spec values that are of interest for the TOSABackend.
35+
3436
"""
3537

3638
@staticmethod
3739
def _compile_tosa_flatbuffer(
3840
tosa_flatbuffer: bytes, compile_spec: EthosUCompileSpec
3941
) -> bytes:
40-
"""
41-
Static helper method to do the compilation of the TOSA flatbuffer
42-
representation to a target specific binary stream.
42+
"""Compile a TOSA flatbuffer into a target-specific binary stream.
43+
44+
Args:
45+
tosa_flatbuffer (bytes): Serialized TOSA graph produced by
46+
``TOSABackend``.
47+
compile_spec (EthosUCompileSpec): Compile specification providing
48+
Vela flags and intermediate paths.
49+
50+
Returns:
51+
bytes: Target-specific binary stream produced by Vela.
52+
4353
"""
4454
compile_flags = compile_spec.compiler_flags
4555

@@ -73,6 +83,17 @@ def preprocess(
7383
edge_program: ExportedProgram,
7484
compile_specs: List[CompileSpec],
7585
) -> PreprocessResult:
86+
"""Lower the exported program and compile it for an Ethos-U target.
87+
88+
Args:
89+
edge_program (ExportedProgram): Program to lower to Ethos-U.
90+
compile_specs (List[CompileSpec]): Serialized Ethos-U compile specs
91+
supplied by the frontend.
92+
93+
Returns:
94+
PreprocessResult: Result containing the compiled Ethos-U binary.
95+
96+
"""
7697
logger.info(f"{EthosUBackend.__name__} preprocess")
7798

7899
compile_spec = EthosUCompileSpec.from_list(compile_specs)

backends/arm/quantizer/arm_quantizer.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from torchao.quantization.pt2e.quantizer import (
4747
annotate_input_qspec_map,
4848
annotate_output_qspec,
49+
get_module_name_filter,
4950
QuantizationSpec,
5051
Quantizer,
5152
)
@@ -248,33 +249,6 @@ def get_symmetric_a16w8_quantization_config(
248249
"""
249250

250251

251-
def _get_module_name_filter(module_name: str) -> NodeFilterType:
252-
"""Get the module_name_filter function for a given module name, the filter accepts
253-
a node and checks if the node comes from a module that has a certain module name
254-
255-
For example:
256-
node: linear_op = call_function[...](...) # comes from a module with name blocks.sub.linear1
257-
258-
>> module_name_filter = _get_module_name_filter("blocks.sub")
259-
>> print(module_name_filter(node))
260-
True # the node is from "blocks.sub" based on the fully qualified name "blocks.sub.linear1"
261-
"""
262-
263-
name_start = len("L['self'].")
264-
265-
def module_name_filter(n: Node) -> bool:
266-
# node_stack example: {
267-
# 'L__self___sub': ("L['self'].sub", <class '....Sub'>),
268-
# 'L__self___sub_linear': ("L['self'].sub.linear", <class 'torch.nn.modules.linear.Linear'>)
269-
# }
270-
# get_attr nodes doesn't have nn_module_stack?
271-
nn_module_stack = n.meta.get("nn_module_stack", {})
272-
names = [name[name_start:] for name, _ in nn_module_stack.values()]
273-
return module_name in names
274-
275-
return module_name_filter
276-
277-
278252
def _get_module_type_filter(tp: Callable) -> NodeFilterType:
279253
"""Get the module_type_filter function for a given module type, the filter accepts
280254
a node and checks if the node comes from a module that has a certain module type
@@ -306,7 +280,7 @@ def _get_not_module_type_or_name_filter(
306280
tp_list: List[Callable], module_name_list: List[str]
307281
) -> NodeFilterType:
308282
module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
309-
module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
283+
module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]
310284

311285
def not_module_type_or_name_filter(n: Node) -> bool:
312286
return not any(f(n) for f in module_type_filters + module_name_list_filters)
@@ -455,7 +429,7 @@ def _annotate_for_static_quantization_config(
455429
module_name_list = list(self.module_name_config.keys())
456430
for module_name, config in self.module_name_config.items():
457431
self._annotate_all_static_patterns(
458-
model, config, _get_module_name_filter(module_name)
432+
model, config, get_module_name_filter(module_name)
459433
)
460434

461435
tp_list = list(self.module_type_config.keys())

backends/arm/test/misc/test_outputs_order.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,18 @@ def _read_tosa_outputs(tosa_path: Path):
7878
return shapes
7979

8080

81+
# TODO: MLETORCH-1266 Investigate output order issue
8182
@pytest.mark.parametrize("batch_size", [1, 4])
82-
def test_network_output_order_and_restore(batch_size):
83+
@pytest.mark.parametrize("output_order_workaround", [True, False])
84+
def test_network_output_order_and_restore(batch_size, output_order_workaround):
8385
model = Network(batch_norm=True).eval()
8486
# Prepare spec
8587
spec = TosaSpecification.create_from_string("TOSA-1.0+INT")
86-
compile_spec = TosaCompileSpec(tosa_spec=spec)
88+
tosa_compile_spec = TosaCompileSpec(spec).set_output_order_workaround(
89+
output_order_workaround
90+
)
8791
# Setup quantizer
88-
quantizer = TOSAQuantizer(compile_spec)
92+
quantizer = TOSAQuantizer(tosa_compile_spec)
8993
quantizer.set_global(
9094
get_symmetric_quantization_config(is_qat=True, is_per_channel=False)
9195
)
@@ -100,7 +104,7 @@ def test_network_output_order_and_restore(batch_size):
100104
with tempfile.TemporaryDirectory(dir="") as tmpdir:
101105
art_dir = Path(tmpdir)
102106
part = TOSAPartitioner(
103-
TosaCompileSpec(spec).dump_intermediate_artifacts_to(str(art_dir))
107+
tosa_compile_spec.dump_intermediate_artifacts_to(str(art_dir))
104108
)
105109
_ = to_edge_transform_and_lower(aten_gm, partitioner=[part])
106110
# Expect exactly one .tosa file in the artefact dir

0 commit comments

Comments
 (0)