Skip to content

Commit 1fa8a8c

Browse files
author
Github Executorch
committed
Add Cortex-M as a first-class target in aot_arm_compiler
Previously, Cortex-M op conversion was applied as an afterthought to all non-vgf targets via transform_for_cortex_m_backend(). This made the flow hard to follow, used a bare EdgeCompileConfig that decomposed ops like linear into addmm (requiring unnecessary workarounds), and didn't use the CortexMQuantizer or CortexMPassManager. Add a dedicated to_edge_cortex_m() path selected via --target=cortex-m-int8 that owns the full pipeline: CortexMQuantizer for INT8 quantization, correct EdgeCompileConfig with preserve_ops to prevent premature decomposition, and CortexMPassManager.pass_list for op conversion. Remove the old scattered transform_for_cortex_m_backend() function. Verified all ops fully lowered to cortex_m::quantized_* operators for both MobileNetV2 (70 nodes) and MobileNetV3 (122 nodes). E2E inference tested on Alif E8 board. Test Plan: python3 -m examples.arm.aot_arm_compiler -m mv2 --target=cortex-m-int8 --quantize --intermediates=./mv2_intermediates --output=./mv2_cortex_m.pte ; python3 -m examples.arm.aot_arm_compiler -m mv3 --target=cortex-m-int8 --quantize --intermediates=./mv3_intermediates --output=./mv3_cortex_m.pte . Also ran E2E inference on Alif E8 board
1 parent f48a600 commit 1fa8a8c

File tree

2 files changed

+79
-40
lines changed

2 files changed

+79
-40
lines changed

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _get_batch_size_from_conv(self, conv_node: torch.fx.Node):
7171

7272
def _get_linear_replacement(self, node):
7373
"""
74-
Let
74+
Let
7575
- yi be the output activations (y1, ... yn)
7676
- xj be the input activations (x1, ... xm)
7777
- wij be the weights (w11, ... wnm)

examples/arm/aot_arm_compiler.py

Lines changed: 78 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,8 @@
3636
from executorch.backends.arm.util._factory import create_partitioner, create_quantizer
3737

3838
from executorch.backends.arm.vgf import VgfCompileSpec
39-
40-
# To use Cortex-M backend
41-
from executorch.backends.cortex_m.passes.convert_to_cortex_m_pass import (
42-
ConvertToCortexMPass,
43-
)
44-
45-
from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import (
46-
QuantizedOpFusionPass,
47-
)
48-
49-
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
50-
ReplaceQuantNodesPass,
51-
)
39+
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
40+
from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
5241

5342
from executorch.devtools import generate_etrecord
5443
from executorch.devtools.backend_debug import get_delegation_info
@@ -399,6 +388,7 @@ def forward(self, x):
399388
"TOSA-1.0+INT",
400389
"TOSA-1.0+FP",
401390
"TOSA-1.0+INT+int16",
391+
"cortex-m-int8",
402392
]
403393

404394

@@ -795,6 +785,75 @@ def to_edge_TOSA_delegate(
795785
return model_quant, edge
796786

797787

788+
def to_edge_cortex_m(
789+
exported_program: ExportedProgram,
790+
args,
791+
model: GraphModule,
792+
example_inputs: Tuple[torch.Tensor],
793+
):
794+
"""Cortex-M/CMSIS-NN compilation path with no delegation."""
795+
logging.info("Using Cortex-M/CMSIS-NN compilation path (no delegation)")
796+
797+
def _to_channels_last(x):
798+
if isinstance(x, torch.Tensor):
799+
if x.dim() == 4 and not x.is_contiguous(memory_format=torch.channels_last):
800+
logging.warning(
801+
"Converting input tensor with shape %s to channels_last",
802+
list(x.shape),
803+
)
804+
return x.to(memory_format=torch.channels_last)
805+
return x
806+
elif isinstance(x, tuple):
807+
return tuple(_to_channels_last(t) for t in x)
808+
return x
809+
810+
if not args.quantize:
811+
logging.warning(
812+
"Quantization is DISABLED. Cortex-M typically requires quantization."
813+
)
814+
else:
815+
model = model.to(memory_format=torch.channels_last)
816+
example_inputs = tuple(_to_channels_last(x) for x in example_inputs)
817+
818+
quantizer = CortexMQuantizer()
819+
prepared = prepare_pt2e(model, quantizer)
820+
821+
dataset = get_calibration_data(
822+
args.model_name, example_inputs, args.evaluate, args.evaluate_config
823+
)
824+
825+
if isinstance(dataset, DataLoader):
826+
for sample, _ in dataset:
827+
prepared(_to_channels_last(sample))
828+
else:
829+
prepared(*tuple(_to_channels_last(x) for x in dataset))
830+
831+
model_quant = convert_pt2e(prepared)
832+
833+
exported_program = torch.export.export(
834+
model_quant, example_inputs, strict=args.strict_export
835+
)
836+
837+
edge = to_edge_transform_and_lower(
838+
exported_program,
839+
compile_config=EdgeCompileConfig(
840+
preserve_ops=[
841+
torch.ops.aten.linear.default,
842+
torch.ops.aten.hardsigmoid.default,
843+
torch.ops.aten.hardsigmoid_.default,
844+
torch.ops.aten.hardswish.default,
845+
torch.ops.aten.hardswish_.default,
846+
],
847+
_check_ir_validity=False,
848+
),
849+
)
850+
851+
pass_manager = CortexMPassManager(edge.exported_program())
852+
edge._edge_programs["forward"] = pass_manager.transform()
853+
854+
return model_quant if args.quantize else None, edge
855+
856+
798857
def to_edge_no_delegate(
799858
exported_program: ExportedProgram,
800859
args,
@@ -830,26 +889,6 @@ def to_edge_no_delegate(
830889
return model_quant, edge
831890

832891

833-
def transform_for_cortex_m_backend(edge_program_manager, args):
834-
# Let's make sure we are using optimized Cortex M backend
835-
# NB: If we can't find and replace ops those are expected to be replaced,
836-
# bad things will happen at runtime, like "missing operator" errors!
837-
838-
# Instantiate the mandatory ReplaceQuantNodesPass
839-
passes = [ReplaceQuantNodesPass]
840-
if args.enable_qdq_fusion_pass:
841-
passes += [ConvertToCortexMPass, QuantizedOpFusionPass]
842-
current_edge = edge_program_manager
843-
for pass_cls in passes:
844-
transform_pass = (
845-
pass_cls(current_edge.exported_program())
846-
if pass_cls.__name__ == "QuantizedLinearFusionPass"
847-
else pass_cls()
848-
)
849-
current_edge = current_edge.transform([transform_pass])
850-
return current_edge
851-
852-
853892
if __name__ == "__main__": # noqa: C901
854893
args = get_args()
855894

@@ -881,7 +920,12 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
881920

882921
# Quantize if required
883922
model_quant = None
884-
if args.delegate:
923+
if args.target == "cortex-m-int8":
924+
# Cortex-M path: CMSIS-NN portable kernels, no delegation
925+
model_quant, edge = to_edge_cortex_m(
926+
exported_program, args, model, example_inputs
927+
)
928+
elif args.delegate:
885929
model_quant, edge = to_edge_TOSA_delegate(
886930
exported_program, args, model, example_inputs
887931
)
@@ -890,11 +934,6 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
890934
exported_program, args, model, example_inputs
891935
)
892936

893-
# Cortex-m ops are never included in vgf or direct-drive
894-
if args.target != "vgf" and not args.direct_drive:
895-
# Transform so we can use ops from the Cortex M backend
896-
edge = transform_for_cortex_m_backend(edge, args)
897-
898937
dump_delegation_info(edge, args.intermediates)
899938

900939
edge_program_manager_copy = copy.deepcopy(edge)

0 commit comments

Comments
 (0)