Skip to content

Commit 105498e

Browse files
Github Executorchpsiddh
authored and committed
Add Cortex-M as a first-class target in aot_arm_compiler
Previously, Cortex-M op conversion was applied as an afterthought to all non-vgf targets via transform_for_cortex_m_backend(). This made the flow hard to follow, used a bare EdgeCompileConfig that decomposed ops like linear into addmm (requiring unnecessary workarounds), and didn't use the CortexMQuantizer or CortexMPassManager. Add a dedicated to_edge_cortex_m() path selected via --target=cortex-m55+int8 that owns the full pipeline: CortexMQuantizer for INT8 quantization, correct EdgeCompileConfig with preserve_ops to prevent premature decomposition, and CortexMPassManager.pass_list for op conversion. Remove the old scattered transform_for_cortex_m_backend() function. Verified all ops fully lowered to cortex_m::quantized_* operators for both MobileNetV2 (70 nodes) and MobileNetV3 (122 nodes). E2E inference tested on Alif E8 board. Test Plan: python3 -m examples.arm.aot_arm_compiler -m mv2 --target=cortex-m55+int8 --quantize --intermediates=./mv2_intermediates --output=./mv2_cortex_m.pte python3 -m examples.arm.aot_arm_compiler -m mv3 --target=cortex-m55+int8 --quantize --intermediates=./mv3_intermediates --output=./mv3_cortex_m.pte Also ran E2E inference on Alif E8 board
1 parent f30d5ed commit 105498e

File tree

2 files changed

+80
-41
lines changed

2 files changed

+80
-41
lines changed

backends/cortex_m/passes/convert_to_cortex_m_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _get_batch_size_from_conv(self, conv_node: torch.fx.Node):
7171

7272
def _get_linear_replacement(self, node):
7373
"""
74-
Let
74+
Let
7575
- yi be the output activations (y1, ... yn)
7676
- xj be the input activations (x1, ... xm)
7777
- wij be the weights (w11, ... wnm)

examples/arm/aot_arm_compiler.py

Lines changed: 79 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,8 @@
3636
from executorch.backends.arm.util._factory import create_partitioner, create_quantizer
3737

3838
from executorch.backends.arm.vgf import VgfCompileSpec
39-
40-
# To use Cortex-M backend
41-
from executorch.backends.cortex_m.passes.convert_to_cortex_m_pass import (
42-
ConvertToCortexMPass,
43-
)
44-
45-
from executorch.backends.cortex_m.passes.quantized_op_fusion_pass import (
46-
QuantizedOpFusionPass,
47-
)
48-
49-
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
50-
ReplaceQuantNodesPass,
51-
)
39+
from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager
40+
from executorch.backends.cortex_m.quantizer.quantizer import CortexMQuantizer
5241

5342
from executorch.devtools import generate_etrecord
5443
from executorch.devtools.backend_debug import get_delegation_info
@@ -396,6 +385,7 @@ def forward(self, x):
396385
"TOSA-1.0+INT",
397386
"TOSA-1.0+FP",
398387
"TOSA-1.0+INT+int16",
388+
"cortex-m55+int8",
399389
]
400390

401391

@@ -528,7 +518,7 @@ def get_args():
528518
required=False,
529519
default="ethos-u55-128",
530520
choices=TARGETS,
531-
help=f"For ArmBackend delegated models, pick the target, and therefore the instruction set generated. valid targets are {TARGETS}",
521+
help=f"Target backend. For delegated models: Ethos-U/VGF/TOSA variants. For non-delegated: cortex-m55+int8 (CMSIS-NN portable kernels). Valid targets: {TARGETS}",
532522
)
533523
parser.add_argument(
534524
"-e",
@@ -790,6 +780,75 @@ def to_edge_TOSA_delegate(
790780
return model_quant, edge
791781

792782

783+
def to_edge_cortex_m(
784+
exported_program: ExportedProgram,
785+
args,
786+
model: GraphModule,
787+
example_inputs: Tuple[torch.Tensor],
788+
):
789+
"""Cortex-M/CMSIS-NN compilation path with no delegation."""
790+
logging.info("Using Cortex-M/CMSIS-NN compilation path (no delegation)")
791+
792+
def _to_channels_last(x):
793+
if isinstance(x, torch.Tensor):
794+
if x.dim() == 4 and not x.is_contiguous(memory_format=torch.channels_last):
795+
logging.warning(
796+
"Converting input tensor with shape %s to channels_last",
797+
list(x.shape),
798+
)
799+
return x.to(memory_format=torch.channels_last)
800+
return x
801+
elif isinstance(x, tuple):
802+
return tuple(_to_channels_last(t) for t in x)
803+
return x
804+
805+
if not args.quantize:
806+
logging.warning(
807+
"Quantization is DISABLED. Cortex-M typically requires quantization."
808+
)
809+
else:
810+
model = model.to(memory_format=torch.channels_last)
811+
example_inputs = tuple(_to_channels_last(x) for x in example_inputs)
812+
813+
quantizer = CortexMQuantizer()
814+
prepared = prepare_pt2e(model, quantizer)
815+
816+
dataset = get_calibration_data(
817+
args.model_name, example_inputs, args.evaluate, args.evaluate_config
818+
)
819+
820+
if isinstance(dataset, DataLoader):
821+
for sample, _ in dataset:
822+
prepared(_to_channels_last(sample))
823+
else:
824+
prepared(*tuple(_to_channels_last(x) for x in dataset))
825+
826+
model_quant = convert_pt2e(prepared)
827+
828+
exported_program = torch.export.export(
829+
model_quant, example_inputs, strict=args.strict_export
830+
)
831+
832+
edge = to_edge_transform_and_lower(
833+
exported_program,
834+
compile_config=EdgeCompileConfig(
835+
preserve_ops=[
836+
torch.ops.aten.linear.default,
837+
torch.ops.aten.hardsigmoid.default,
838+
torch.ops.aten.hardsigmoid_.default,
839+
torch.ops.aten.hardswish.default,
840+
torch.ops.aten.hardswish_.default,
841+
],
842+
_check_ir_validity=False,
843+
),
844+
)
845+
846+
pass_manager = CortexMPassManager(edge.exported_program())
847+
edge._edge_programs["forward"] = pass_manager.transform()
848+
849+
return model_quant if args.quantize else None, edge
850+
851+
793852
def to_edge_no_delegate(
794853
exported_program: ExportedProgram,
795854
args,
@@ -825,26 +884,6 @@ def to_edge_no_delegate(
825884
return model_quant, edge
826885

827886

828-
def transform_for_cortex_m_backend(edge_program_manager, args):
829-
# Let's make sure we are using optimized Cortex M backend
830-
# NB: If we can't find and replace ops those are expected to be replaced,
831-
# bad things will happen at runtime, like "missing operator" errors!
832-
833-
# Instantiate the mandatory ReplaceQuantNodesPass
834-
passes = [ReplaceQuantNodesPass]
835-
if args.enable_qdq_fusion_pass:
836-
passes += [ConvertToCortexMPass, QuantizedOpFusionPass]
837-
current_edge = edge_program_manager
838-
for pass_cls in passes:
839-
transform_pass = (
840-
pass_cls(current_edge.exported_program())
841-
if pass_cls.__name__ == "QuantizedLinearFusionPass"
842-
else pass_cls()
843-
)
844-
current_edge = current_edge.transform([transform_pass])
845-
return current_edge
846-
847-
848887
if __name__ == "__main__": # noqa: C901
849888
args = get_args()
850889

@@ -876,7 +915,12 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
876915

877916
# Quantize if required
878917
model_quant = None
879-
if args.delegate:
918+
if args.target == "cortex-m55+int8":
919+
# Cortex-M path: CMSIS-NN portable kernels, no delegation
920+
model_quant, edge = to_edge_cortex_m(
921+
exported_program, args, model, example_inputs
922+
)
923+
elif args.delegate:
880924
model_quant, edge = to_edge_TOSA_delegate(
881925
exported_program, args, model, example_inputs
882926
)
@@ -885,11 +929,6 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
885929
exported_program, args, model, example_inputs
886930
)
887931

888-
# Cortex-m ops are never included in vgf or direct-drive
889-
if args.target != "vgf" and not args.direct_drive:
890-
# Transform so we can use ops from the Cortex M backend
891-
edge = transform_for_cortex_m_backend(edge, args)
892-
893932
dump_delegation_info(edge, args.intermediates)
894933

895934
edge_program_manager_copy = copy.deepcopy(edge)

0 commit comments

Comments
 (0)