Skip to content

Commit 6fbf0d4

Browse files
authored
Merge branch 'main' into gh/gasoonjia/33/orig
2 parents 995ffca + 9d86cbe commit 6fbf0d4

25 files changed

+559
-539
lines changed

backends/arm/CMakeLists.txt

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ endif()
1414

1515
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
1616

17-
set(_common_include_directories ${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)
17+
set(_common_include_directories
18+
${EXECUTORCH_ROOT}/.. ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10
19+
)
1820
add_compile_definitions(C10_USING_CUSTOM_GENERATED_MACROS)
1921

2022

@@ -34,13 +36,12 @@ set(_arm_baremetal_sources backends/arm/runtime/EthosUBackend.cpp
3436
list(TRANSFORM _arm_baremetal_sources PREPEND "${EXECUTORCH_ROOT}/")
3537

3638
add_library(executorch_delegate_ethos_u STATIC ${_arm_baremetal_sources})
37-
target_include_directories(
38-
executorch_delegate_ethos_u PUBLIC ${_common_include_directories}
39-
)
40-
target_include_directories(
41-
executorch_delegate_ethos_u PUBLIC ${DRIVER_ETHOSU_INCLUDE_DIR}
39+
target_link_libraries(
40+
executorch_delegate_ethos_u PUBLIC executorch_core ethosu_core_driver
4241
)
4342

43+
install(TARGETS executorch_delegate_ethos_u EXPORT ExecuTorchTargets)
44+
4445
# end config for bare metal builds
4546
endif()
4647

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
from executorch.backends.arm._passes.arm_pass_utils import create_node
1414

15-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
15+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1616
from executorch.exir.dialects._ops import ops as exir_ops
1717
from executorch.exir.dialects.edge._ops import EdgeOpOverload
1818
from executorch.exir.pass_base import ExportPass, PassResult
@@ -62,7 +62,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
6262
}
6363
for partition in matmul_partitions:
6464
quantized_input = all(
65-
input_node.target in dq_ops for input_node in partition.input_nodes
65+
input_node.target in DQ_OPS for input_node in partition.input_nodes
6666
)
6767
matmul_node = [
6868
node for node in partition.nodes if node.target in matmul_targets
@@ -93,7 +93,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
9393
graph_module.graph.erase_node(partition_input)
9494

9595
partition_output = list(partition.output_nodes[0].users)[0]
96-
quantized_output = partition_output.target in q_ops
96+
quantized_output = partition_output.target in Q_OPS
9797
if quantized_output:
9898
with graph_module.graph.inserting_after(matmul_node):
9999
# Create q-node after matmul

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
get_param_tensor,
1616
is_param_node,
1717
)
18+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1819

19-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs
20+
from executorch.backends.arm.tosa_quant_utils import QuantArgs
2021

2122
from executorch.exir.dialects._ops import ops as exir_ops
2223
from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -109,7 +110,7 @@ def fold_and_annotate_arg(
109110
return
110111

111112
arg_quant_params = None
112-
if arg.target in dq_ops:
113+
if arg.target in DQ_OPS:
113114
args = arg.args
114115
scales = args[1]
115116
if (
@@ -137,9 +138,9 @@ def fold_and_annotate_arg(
137138
if input_qparams is not None:
138139
node.meta["input_qparams"][i] = input_qparams
139140
for n in nodes_to_remove:
140-
if n.target not in dq_ops:
141+
if n.target not in DQ_OPS:
141142
raise RuntimeError(
142-
f"Expected one of {dq_ops} dq_op, got {n.target}"
143+
f"Expected one of {DQ_OPS} dq_op, got {n.target}"
143144
)
144145

145146
node.replace_input_with(n, cast(Node, n.args[0]))
@@ -154,7 +155,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
154155
if n.op != "call_function":
155156
continue
156157
# Don't fold chains of quant-ops into each other.
157-
if n.target in (*q_ops, *dq_ops):
158+
if n.target in (*Q_OPS, *DQ_OPS):
158159
continue
159160

160161
# Make sure we haven't already set qparams meta information on the node
@@ -184,7 +185,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
184185
# Copy the users, since we are modifying it.
185186
users_copy = copy.copy(n.users)
186187
for i, user in enumerate(users_copy):
187-
if user.target not in q_ops:
188+
if user.target not in Q_OPS:
188189
continue
189190

190191
# quantization node found here, store the quantization parameters in meta value
@@ -221,7 +222,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
221222

222223
# Make sure we have a quantized operator
223224
user = list(n.users)[0]
224-
if user.target not in q_ops:
225+
if user.target not in Q_OPS:
225226
continue
226227

227228
qargs = QuantArgs.from_operator(user.target, user.args)

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
# pyre-unsafe
77

88
import torch
9-
from executorch.backends.arm.tosa_quant_utils import q_ops, QuantArgs
9+
from executorch.backends.arm.constants import Q_OPS
10+
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1011
from executorch.exir.dialects._ops import ops as exir_ops
1112
from executorch.exir.pass_base import ExportPass, PassResult
1213
from torch.fx import Node
@@ -21,7 +22,7 @@ def _is_fuseable_quantized_activation(node: Node):
2122
min_val = node.args[1]
2223
is_fuseable = min_val == 0
2324

24-
is_quantized = len(node.users) == 1 and next(iter(node.users)).target in q_ops
25+
is_quantized = len(node.users) == 1 and next(iter(node.users)).target in Q_OPS
2526
if is_fuseable and is_quantized:
2627
quant_node = next(iter(node.users))
2728
quant_args = QuantArgs.from_operator(quant_node.target, quant_node.args)

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import torch
1111
from executorch.backends.arm._passes.arm_pass_utils import create_node
12-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops, QuantArgs
12+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
13+
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1314
from executorch.exir.pass_base import ExportPass, PassResult
1415
from torch import Tensor
1516
from torch.fx import GraphModule, Node
@@ -94,11 +95,11 @@ def call(self, graph_module: GraphModule) -> PassResult:
9495
for node in graph_module.graph.nodes:
9596
node = cast(Node, node)
9697

97-
if node.target not in dq_ops:
98+
if node.target not in DQ_OPS:
9899
continue
99100
# Copy users since we remove them while iterating, modyfing the node.users list.
100101
for user in copy(node.users):
101-
if user.target in q_ops:
102+
if user.target in Q_OPS:
102103
self.fold_dq_q_to_rescale(node, user, graph_module)
103104
modified = True
104105
if len(node.users) == 0:

backends/arm/_passes/mm_to_bmm_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
get_first_fake_tensor,
1313
insert_q_dq_pair,
1414
)
15-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
15+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1616
from executorch.exir.dialects._ops import ops as exir_ops
1717
from executorch.exir.pass_base import ExportPass, PassResult
1818
from torch.fx import Node
@@ -56,7 +56,7 @@ def call(self, graph_module: torch.fx.GraphModule):
5656
node.replace_input_with(input_node, unsqueeze_before)
5757

5858
# If Quantized we must insert unsqueeze --> q --> dq --> node
59-
if input_node.target in dq_ops:
59+
if input_node.target in DQ_OPS:
6060
q_params = input_node.args[1:]
6161
insert_q_dq_pair(graph, unsqueeze_before, q_params, from_node=node)
6262

@@ -89,7 +89,7 @@ def call(self, graph_module: torch.fx.GraphModule):
8989
user.replace_input_with(bmm_node, squeeze_after)
9090

9191
# If quantized, insert mm --> q --> dq --> squeeze
92-
if all(original_user.target in q_ops for original_user in original_users):
92+
if all(original_user.target in Q_OPS for original_user in original_users):
9393
q_params = original_users[0].args[1:]
9494
insert_q_dq_pair(graph, bmm_node, q_params, from_node=node)
9595

backends/arm/constants.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Any, cast, Final
7+
8+
from executorch.exir.dialects._ops import ops as exir_ops
9+
10+
exir_ops = cast(Any, exir_ops)
11+
12+
qd = exir_ops.edge.quantized_decomposed
13+
14+
QUANT_PER_TENSOR_OP: Final = qd.quantize_per_tensor.default
15+
QUANT_PER_TENSOR_OP_T: Final = qd.quantize_per_tensor.tensor
16+
QUANT_PER_CHANNEL_OP: Final = qd.quantize_per_channel.default
17+
18+
DEQUANT_PER_TENSOR_OP: Final = qd.dequantize_per_tensor.default
19+
DEQUANT_PER_TENSOR_OP_T: Final = qd.dequantize_per_tensor.tensor
20+
DEQUANT_PER_CHANNEL_OP: Final = qd.dequantize_per_channel.default
21+
22+
Q_OPS: Final = (QUANT_PER_TENSOR_OP, QUANT_PER_TENSOR_OP_T, QUANT_PER_CHANNEL_OP)
23+
DQ_OPS: Final = (DEQUANT_PER_TENSOR_OP, DEQUANT_PER_TENSOR_OP_T, DEQUANT_PER_CHANNEL_OP)
24+
25+
PER_TENSOR_QDQ_OPS: Final = (
26+
QUANT_PER_TENSOR_OP,
27+
QUANT_PER_TENSOR_OP_T,
28+
DEQUANT_PER_TENSOR_OP,
29+
DEQUANT_PER_TENSOR_OP_T,
30+
)
31+
PER_CHANNEL_QDQ_OPS: Final = (QUANT_PER_CHANNEL_OP, DEQUANT_PER_CHANNEL_OP)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
FuseQuantizedActivationPass,
2020
)
2121
from executorch.backends.arm._passes.insert_table_ops import TableOps
22+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
2223
from executorch.backends.arm.operator_support.ethos_u55_support import (
2324
EthosU55DtypeSupport,
2425
EthosU55NotSupported,
2526
EthosU55TransposeCheck,
2627
EthosU55ViewCheck,
2728
)
28-
from executorch.backends.arm.tosa_quant_utils import dq_ops, q_ops
2929
from executorch.backends.arm.tosa_specification import TosaSpecification
3030
from executorch.exir import ExportedProgram
3131
from executorch.exir.backend.utils import WhyNoPartitionReporter
@@ -368,7 +368,7 @@ def _is_matmul_node_supported(
368368
matched_partition = partition
369369
if matched_partition is not None:
370370
input_quantized = all(
371-
input_node.target in dq_ops
371+
input_node.target in DQ_OPS
372372
for input_node in matched_partition.input_nodes
373373
)
374374
if not input_quantized:
@@ -377,7 +377,7 @@ def _is_matmul_node_supported(
377377
)
378378
return False
379379
output_quantized = all(
380-
output_node_user.target in q_ops
380+
output_node_user.target in Q_OPS
381381
for output_node_user in matched_partition.output_nodes[0].users
382382
)
383383
if not output_quantized:
@@ -413,7 +413,7 @@ def is_node_supported(
413413
users = node.users
414414
output_quantized = all(
415415
user.target == operator.getitem
416-
and all(user_user.target in q_ops for user_user in user.users)
416+
and all(user_user.target in Q_OPS for user_user in user.users)
417417
for user in users
418418
)
419419
elif FuseQuantizedActivationPass._is_fuseable_input(node):
@@ -427,7 +427,7 @@ def is_node_supported(
427427
input_quantized = FuseQuantizedActivationPass._is_fuseable_input(input_node)
428428

429429
input_quantized = input_quantized or all(
430-
(input_node.target in dq_ops)
430+
(input_node.target in DQ_OPS)
431431
or (not get_first_fake_tensor(input_node).dtype.is_floating_point)
432432
for input_node in node.all_input_nodes
433433
)
@@ -436,7 +436,7 @@ def is_node_supported(
436436
self.reporter.report_reject(node, "One or more inputs were not quantized.")
437437
return False
438438

439-
all_q_users = all((output_node.target in q_ops) for output_node in node.users)
439+
all_q_users = all((output_node.target in Q_OPS) for output_node in node.users)
440440
is_floating_point = get_first_fake_tensor(node).dtype.is_floating_point
441441
output_quantized = output_quantized or all_q_users or not is_floating_point
442442

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 2 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,9 @@
1111
# Utility functions for TOSAQuantizer
1212
#
1313

14-
from typing import cast, Sequence
14+
from typing import cast
1515

16-
import torch
17-
from torch._subclasses import FakeTensor
18-
from torch.fx import GraphModule, Node
16+
from torch.fx import Node
1917

2018
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
2119
from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
@@ -45,62 +43,3 @@ def mark_node_as_annotated(node: Node) -> None:
4543
if Q_ANNOTATION_KEY not in node.meta:
4644
node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
4745
node.meta[Q_ANNOTATION_KEY]._annotated = True
48-
49-
50-
def is_ok_for_quantization(node: Node, gm: GraphModule):
51-
"""Check if an node can be quantized. The node can not be quantized if:
52-
- The node does not output a float tensor or,
53-
- The node outputs a large scalar.
54-
"""
55-
return not (is_non_float_tensor(node) or is_large_scalar(node, gm))
56-
57-
58-
def get_node_target(module: torch.nn.Module | GraphModule, target_str: str):
59-
targets = target_str.split(".")
60-
for target in targets[:-1]:
61-
module = module.get_submodule(target)
62-
return getattr(module, targets[-1])
63-
64-
65-
def is_large_scalar(node: Node, gm: GraphModule):
66-
"""Check if input is a large scalar value. So that we can skip quantization for the node
67-
since histc op (in HistogramObserver) only works for values up to certain upper bound
68-
"""
69-
if node.op == "get_attr" and isinstance(node.target, str):
70-
tensor = get_node_target(gm, node.target)
71-
# torch.histc works until this upper bound
72-
HISTC_UPPER_BOUND = 3.4028235e15
73-
return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
74-
return False
75-
76-
77-
def is_non_float_tensor(node: Node) -> bool:
78-
"""Check if the output of a node has a data type other than `torch.float32`.
79-
80-
If the output is not `torch.float32`, quantization cannot be performed, as
81-
observers only work with floating-point tensors.
82-
83-
Args:
84-
node (Node): The node to check the output(s) for.
85-
86-
Returns:
87-
bool: `True` if the data type is not float32, otherwise `False`.
88-
89-
Note:
90-
- If `node.meta["val"]` is a `list`, the function returns `True` if **any**
91-
element is **not** an instance of `FakeTensor` or does **not** have
92-
`torch.float32` as its data type.
93-
- If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
94-
function returns True.
95-
"""
96-
if "val" in node.meta and isinstance(node.meta["val"], Sequence):
97-
return any(
98-
not isinstance(fake_tensor, FakeTensor)
99-
or fake_tensor.dtype != torch.float32
100-
for fake_tensor in node.meta["val"]
101-
)
102-
103-
if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
104-
return True
105-
106-
return node.meta["val"].dtype != torch.float32

0 commit comments

Comments
 (0)