Skip to content

Commit c70aeda

Browse files
authored
Cortex_m backend: Loosen edge op check. (#13550)
The pass checked that all ops were edge ops to detect if the pass was ran before lowering to edge. However, there are cases where aten ops survive after edge lowering, notably torch.ops.tensor_scalar. This shouldn't crash the pass. Instead, only check that q/dq ops are edge ops. cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 Signed-off-by: Erik Lundell <[email protected]>
1 parent 2733de8 commit c70aeda

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

backends/cortex_m/passes/replace_quant_nodes_pass.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
22
# All rights reserved.
3+
# Copyright 2025 Arm Limited and/or its affiliates.
34
#
45
# This source code is licensed under the BSD-style license found in the
56
# LICENSE file in the root directory of this source tree.
@@ -10,7 +11,6 @@
1011
import torch
1112

1213
from executorch.exir.dialects._ops import ops as exir_ops
13-
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1414
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
1515

1616

@@ -40,6 +40,10 @@ def __init__(self):
4040
"qualifier": self._is_qualified_int8_node,
4141
},
4242
}
43+
self.disallowed_targets = {
44+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
45+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
46+
}
4347

4448
def call_operator(
4549
self,
@@ -48,9 +52,10 @@ def call_operator(
4852
kwargs: Dict[str, object],
4953
meta: NodeMetadata,
5054
) -> ProxyValue:
51-
assert isinstance(
52-
op, EdgeOpOverload
53-
), "Op must be an EdgeOpOverload. Run this pass after to_edge()."
55+
if op in self.disallowed_targets:
56+
raise RuntimeError(
57+
f"Found unexpected aten op '{op}'. Make sure you run this pass after lowering to edge."
58+
)
5459

5560
if op in self.op_replacements and self.op_replacements[op]["qualifier"](args):
5661
return super().call_operator(

backends/cortex_m/test/test_replace_quant_nodes.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@
99
from dataclasses import dataclass
1010
from typing import Optional
1111

12-
import executorch
1312
import executorch.backends.cortex_m.ops.operators # noqa
1413

14+
import executorch.exir
15+
1516
import torch
1617
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
1718
ReplaceQuantNodesPass,
1819
)
1920
from executorch.exir.dialects._ops import ops as exir_ops
21+
from executorch.exir.program._program import _transform
2022
from torch.export import export
2123
from torch.fx import GraphModule
2224
from torchao.quantization.pt2e.observer import HistogramObserver
@@ -128,11 +130,18 @@ def forward(self, x):
128130
# Step 1: Export and quantize the model
129131
exported_model = export(model.eval(), example_inputs, strict=True).module()
130132
prepared_model = prepare_pt2e(exported_model, AddQuantizer())
133+
prepared_model(*example_inputs)
131134
quantized_model = convert_pt2e(prepared_model)
132135

133136
# Step 2: Export to EXIR
134137
exported = export(quantized_model, example_inputs, strict=True)
135138

139+
# The pass should raise an Exception if ran before to_edge.
140+
with self.assertRaisesRegex(
141+
Exception, "An error occurred when running the 'ReplaceQuantNodesPass' pass"
142+
):
143+
_transform(exported, ReplaceQuantNodesPass())
144+
136145
# Step 3: Convert to Edge
137146
edge_program = executorch.exir.to_edge(
138147
exported,

0 commit comments

Comments
 (0)