
Commit 3652afd

Update on "migrate etrecord generation after to_edge_transform_and_lower to new infra"
Differential Revision: [D79420502](https://our.internmc.facebook.com/intern/diff/D79420502/) [ghstack-poisoned]
2 parents: 24d8c68 + c082da0

29 files changed: +599, -544 lines
Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+#pragma once
+
+namespace executorch::core_ml_backend_delegate {
+void register_backend_coreml();
+} // namespace executorch::core_ml_backend_delegate
Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#pragma once
+
+#include "executorch_operations.h"
+#import <coreml_backend/delegate.h>
+#import "ETCoreMLStrings.h"
+#import "backend_delegate.h"
+
+#import <executorch/runtime/core/evalue.h>
+#import <executorch/runtime/platform/log.h>
+#import <executorch/runtime/backend/interface.h>
+
+#include <array>
+#import <memory>
+
+namespace executorch::core_ml_backend_delegate {
+using executorch::runtime::get_backend_class;
+
+static std::unique_ptr<executorch::backends::coreml::CoreMLBackendDelegate> backendInterfaceLazy_;
+
+void register_backend_coreml() {
+  auto backendInterface = executorch::runtime::get_backend_class(ETCoreMLStrings.delegateIdentifier.UTF8String);
+  if (backendInterface == nullptr) {
+    backendInterfaceLazy_ = std::make_unique<executorch::backends::coreml::CoreMLBackendDelegate>();
+    executorch::runtime::Backend backend{ETCoreMLStrings.delegateIdentifier.UTF8String, backendInterfaceLazy_.get()};
+    std::ignore = register_backend(backend);
+  }
+}
+
+} // namespace executorch::core_ml_backend_delegate
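Note on the pattern: register_backend_coreml() is idempotent. It first asks the runtime registry whether anything is already registered under the CoreML delegate identifier, and only constructs and registers a CoreMLBackendDelegate when that lookup comes back empty, so repeated calls are harmless. A minimal Python sketch of the same guard-and-register idea (the registry dict and factory argument are illustrative stand-ins, not ExecuTorch APIs):

```python
_backends: dict[str, object] = {}

def register_backend_once(name: str, factory) -> None:
    """Register a backend under `name` only if that name is still free."""
    # Mirrors the nullptr check in register_backend_coreml(): look up
    # before constructing, so repeated calls never double-register.
    if name not in _backends:
        _backends[name] = factory()
```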

backends/arm/CMakeLists.txt

Lines changed: 7 additions & 6 deletions

@@ -14,7 +14,9 @@ endif()
 
 include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 
-set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)
+set(_common_include_directories
+  ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10
+)
 add_compile_definitions(C10_USING_CUSTOM_GENERATED_MACROS)
 
 
@@ -34,13 +36,12 @@ set(_arm_baremetal_sources backends/arm/runtime/EthosUBackend.cpp
 list(TRANSFORM _arm_baremetal_sources PREPEND "${EXECUTORCH_ROOT}/")
 
 add_library(executorch_delegate_ethos_u STATIC ${_arm_baremetal_sources})
-target_include_directories(
-  executorch_delegate_ethos_u PUBLIC ${_common_include_directories}
-)
-target_include_directories(
-  executorch_delegate_ethos_u PUBLIC ${DRIVER_ETHOSU_INCLUDE_DIR}
+target_link_libraries(
+  executorch_delegate_ethos_u PUBLIC executorch_core ethosu_core_driver
 )
 
+install(TARGETS executorch_delegate_ethos_u EXPORT ExecuTorchTargets)
+
 # end config for bare metal builds
 endif()
 

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 3 additions & 3 deletions

@@ -12,7 +12,7 @@
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
 
-from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
+from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -62,7 +62,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
         }
         for partition in matmul_partitions:
             quantized_input = all(
-                input_node.target in dq_ops for input_node in partition.input_nodes
+                input_node.target in DQ_OPS for input_node in partition.input_nodes
             )
             matmul_node = [
                 node for node in partition.nodes if node.target in matmul_targets
@@ -93,7 +93,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     graph_module.graph.erase_node(partition_input)
 
             partition_output = list(partition.output_nodes[0].users)[0]
-            quantized_output = partition_output.target in q_ops
+            quantized_output = partition_output.target in Q_OPS
             if quantized_output:
                 with graph_module.graph.inserting_after(matmul_node):
                     # Create q-node after matmul

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 8 additions & 7 deletions

@@ -15,8 +15,9 @@
     get_param_tensor,
     is_param_node,
 )
+from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 
-from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs
+from executorch.backends.arm.tosa_quant_utils import QuantArgs
 
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -109,7 +110,7 @@ def fold_and_annotate_arg(
             return
 
         arg_quant_params = None
-        if arg.target in dq_ops:
+        if arg.target in DQ_OPS:
             args = arg.args
             scales = args[1]
             if (
@@ -137,9 +138,9 @@ def fold_and_annotate_arg(
         if input_qparams is not None:
             node.meta["input_qparams"][i] = input_qparams
         for n in nodes_to_remove:
-            if n.target not in dq_ops:
+            if n.target not in DQ_OPS:
                 raise RuntimeError(
-                    f"Expected one of {dq_ops} dq_op, got {n.target}"
+                    f"Expected one of {DQ_OPS} dq_op, got {n.target}"
                 )
 
             node.replace_input_with(n, cast(Node, n.args[0]))
@@ -154,7 +155,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
             if n.op != "call_function":
                 continue
             # Don't fold chains of quant-ops into each other.
-            if n.target in (*q_ops, *dq_ops):
+            if n.target in (*Q_OPS, *DQ_OPS):
                 continue
 
             # Make sure we haven't already set qparams meta information on the node
@@ -184,7 +185,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
             # Copy the users, since we are modifying it.
             users_copy = copy.copy(n.users)
             for i, user in enumerate(users_copy):
-                if user.target not in q_ops:
+                if user.target not in Q_OPS:
                     continue
 
                 # quantization node found here, store the quantization parameters in meta value
@@ -221,7 +222,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
             # Make sure we have a quantized operator
             user = list(n.users)[0]
-            if user.target not in q_ops:
+            if user.target not in Q_OPS:
                 continue
 
             qargs = QuantArgs.from_operator(user.target, user.args)

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 3 additions & 2 deletions

@@ -6,7 +6,8 @@
 # pyre-unsafe
 
 import torch
-from executorch.backends.arm.tosa_quant_utils import q_ops, QuantArgs
+from executorch.backends.arm.constants import Q_OPS
+from executorch.backends.arm.tosa_quant_utils import QuantArgs
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import Node
@@ -21,7 +22,7 @@ def _is_fuseable_quantized_activation(node: Node):
         min_val = node.args[1]
         is_fuseable = min_val == 0
 
-        is_quantized = len(node.users) == 1 and next(iter(node.users)).target in q_ops
+        is_quantized = len(node.users) == 1 and next(iter(node.users)).target in Q_OPS
         if is_fuseable and is_quantized:
             quant_node = next(iter(node.users))
             quant_args = QuantArgs.from_operator(quant_node.target, quant_node.args)

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 4 additions & 3 deletions

@@ -9,7 +9,8 @@
 
 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
-from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs
+from executorch.backends.arm.constants import DQ_OPS, Q_OPS
+from executorch.backends.arm.tosa_quant_utils import QuantArgs
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch import Tensor
 from torch.fx import GraphModule, Node
@@ -94,11 +95,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
         for node in graph_module.graph.nodes:
             node = cast(Node, node)
 
-            if node.target not in dq_ops:
+            if node.target not in DQ_OPS:
                 continue
             # Copy users since we remove them while iterating, modifying the node.users list.
             for user in copy(node.users):
-                if user.target in q_ops:
+                if user.target in Q_OPS:
                     self.fold_dq_q_to_rescale(node, user, graph_module)
                     modified = True
             if len(node.users) == 0:

backends/arm/_passes/mm_to_bmm_pass.py

Lines changed: 3 additions & 3 deletions

@@ -12,7 +12,7 @@
     get_first_fake_tensor,
     insert_q_dq_pair,
 )
-from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
+from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import Node
@@ -56,7 +56,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             node.replace_input_with(input_node, unsqueeze_before)
 
             # If quantized, we must insert unsqueeze --> q --> dq --> node
-            if input_node.target in dq_ops:
+            if input_node.target in DQ_OPS:
                 q_params = input_node.args[1:]
                 insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)
 
@@ -89,7 +89,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                 user.replace_input_with(bmm_node, squeeze_after)
 
             # If quantized, insert mm --> q --> dq --> squeeze
-            if all(original_user.target in q_ops for original_user in original_users):
+            if all(original_user.target in Q_OPS for original_user in original_users):
                 q_params = original_users[0].args[1:]
                 insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)

backends/arm/constants.py

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, cast, Final
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+exir_ops = cast(Any, exir_ops)
+
+qd = exir_ops.edge.quantized_decomposed
+
+QUANT_PER_TENSOR_OP: Final = qd.quantize_per_tensor.default
+QUANT_PER_TENSOR_OP_T: Final = qd.quantize_per_tensor.tensor
+QUANT_PER_CHANNEL_OP: Final = qd.quantize_per_channel.default
+
+DEQUANT_PER_TENSOR_OP: Final = qd.dequantize_per_tensor.default
+DEQUANT_PER_TENSOR_OP_T: Final = qd.dequantize_per_tensor.tensor
+DEQUANT_PER_CHANNEL_OP: Final = qd.dequantize_per_channel.default
+
+Q_OPS: Final = (QUANT_PER_TENSOR_OP, QUANT_PER_TENSOR_OP_T, QUANT_PER_CHANNEL_OP)
+DQ_OPS: Final = (DEQUANT_PER_TENSOR_OP, DEQUANT_PER_TENSOR_OP_T, DEQUANT_PER_CHANNEL_OP)
+
+PER_TENSOR_QDQ_OPS: Final = (
+    QUANT_PER_TENSOR_OP,
+    QUANT_PER_TENSOR_OP_T,
+    DEQUANT_PER_TENSOR_OP,
+    DEQUANT_PER_TENSOR_OP_T,
+)
+PER_CHANNEL_QDQ_OPS: Final = (QUANT_PER_CHANNEL_OP, DEQUANT_PER_CHANNEL_OP)
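The new constants module centralizes the quantize/dequantize edge ops that previously lived in tosa_quant_utils as the lowercase dq_ops/q_ops tuples, which is what drives the mechanical DQ_OPS/Q_OPS renames across the surrounding passes. A minimal sketch of the membership checks those passes perform with the new tuples (assuming a torch.fx Node from an edge-dialect graph; the helper name is illustrative):

```python
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from torch.fx import Node

def io_is_quantized(node: Node) -> bool:
    # Same checks as the passes in this diff: inputs are quantized when
    # each one is produced by a dequantize op, and outputs are quantized
    # when every user is a quantize op.
    inputs_quantized = all(inp.target in DQ_OPS for inp in node.all_input_nodes)
    outputs_quantized = all(user.target in Q_OPS for user in node.users)
    return inputs_quantized and outputs_quantized
```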

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 6 additions & 6 deletions

@@ -19,13 +19,13 @@
     FuseQuantizedActivationPass,
 )
 from executorch.backends.arm._passes.insert_table_ops import TableOps
+from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.backends.arm.operator_support.ethos_u55_support import (
     EthosU55DtypeSupport,
     EthosU55NotSupported,
     EthosU55TransposeCheck,
     EthosU55ViewCheck,
 )
-from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from executorch.exir import ExportedProgram
 from executorch.exir.backend.utils import WhyNoPartitionReporter
@@ -368,7 +368,7 @@ def _is_matmul_node_supported(
                 matched_partition = partition
         if matched_partition is not None:
             input_quantized = all(
-                input_node.target in dq_ops
+                input_node.target in DQ_OPS
                 for input_node in matched_partition.input_nodes
             )
             if not input_quantized:
@@ -377,7 +377,7 @@ def _is_matmul_node_supported(
                 )
                 return False
             output_quantized = all(
-                output_node_user.target in q_ops
+                output_node_user.target in Q_OPS
                 for output_node_user in matched_partition.output_nodes[0].users
             )
             if not output_quantized:
@@ -413,7 +413,7 @@ def is_node_supported(
             users = node.users
             output_quantized = all(
                 user.target == operator.getitem
-                and all(user_user.target in q_ops for user_user in user.users)
+                and all(user_user.target in Q_OPS for user_user in user.users)
                 for user in users
             )
         elif FuseQuantizedActivationPass._is_fuseable_input(node):
@@ -427,7 +427,7 @@ def is_node_supported(
             input_quantized = FuseQuantizedActivationPass._is_fuseable_input(input_node)
 
             input_quantized = input_quantized or all(
-                (input_node.target in dq_ops)
+                (input_node.target in DQ_OPS)
                 or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
                 for input_node in node.all_input_nodes
            )
@@ -436,7 +436,7 @@ def is_node_supported(
             self.reporter.report_reject(node, "One or more inputs were not quantized.")
             return False
 
-        all_q_users = all((output_node.target in q_ops) for output_node in node.users)
+        all_q_users = all((output_node.target in Q_OPS) for output_node in node.users)
         is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
         output_quantized = output_quantized or all_q_users or not is_floating_point