Commit 09e8a1f
Merge branch 'main' into export-D83810321
2 parents 733d54d + 0bfb61e commit 09e8a1f

106 files changed (+2011, -556 lines)


.ci/scripts/test_backend.sh

Lines changed: 1 addition & 0 deletions
@@ -59,6 +59,7 @@ fi
if [[ "$FLOW" == *arm* ]]; then
  # Setup ARM deps.
  .ci/scripts/setup-arm-baremetal-tools.sh
+  source examples/arm/ethos-u-scratch/setup_path.sh

  if [[ "$FLOW" == *ethos_u* ]]; then
    # Prepare a test runner binary that can run on the Corstone-3x0 FVPs

.ci/scripts/test_model.sh

Lines changed: 7 additions & 4 deletions
@@ -48,22 +48,25 @@ prepare_artifacts_upload() {
  fi
}

+
build_cmake_executor_runner() {
  local backend_string_select="${1:-}"
  echo "Building executor_runner"
  rm -rf ${CMAKE_OUTPUT_DIR}
  mkdir ${CMAKE_OUTPUT_DIR}
+  # Common options:
+  COMMON="-DPYTHON_EXECUTABLE=$PYTHON_EXECUTABLE"
  if [[ "$backend_string_select" == "XNNPACK" ]]; then
    echo "Backend $backend_string_select selected"
-    (cd ${CMAKE_OUTPUT_DIR} \
-      && cmake -DCMAKE_BUILD_TYPE=Release \
+    cmake -DCMAKE_BUILD_TYPE=Release \
      -DEXECUTORCH_BUILD_XNNPACK=ON \
-      -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" ..)
+      ${COMMON} \
+      -B${CMAKE_OUTPUT_DIR} .
    cmake --build ${CMAKE_OUTPUT_DIR} -j4
  else
    cmake -DCMAKE_BUILD_TYPE=Debug \
      -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
-      -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
+      ${COMMON} \
      -B${CMAKE_OUTPUT_DIR} .
    cmake --build ${CMAKE_OUTPUT_DIR} -j4 --config Debug
  fi

.ci/scripts/utils.sh

Lines changed: 4 additions & 3 deletions
@@ -125,14 +125,15 @@ build_executorch_runner_cmake() {
  clean_executorch_install_folders
  mkdir "${CMAKE_OUTPUT_DIR}"

-  pushd "${CMAKE_OUTPUT_DIR}" || return
  if [[ $1 == "Debug" ]]; then
    CXXFLAGS="-fsanitize=address,undefined"
  else
    CXXFLAGS=""
  fi
-  CXXFLAGS="$CXXFLAGS" retry cmake -DPYTHON_EXECUTABLE="${PYTHON_EXECUTABLE}" -DCMAKE_BUILD_TYPE="${1:-Release}" ..
-  popd || return
+  CXXFLAGS="$CXXFLAGS" retry cmake \
+    -DPYTHON_EXECUTABLE="${PYTHON_EXECUTABLE}" \
+    -DCMAKE_BUILD_TYPE="${1:-Release}" \
+    -B${CMAKE_OUTPUT_DIR} .

  if [ "$(uname)" == "Darwin" ]; then
    CMAKE_JOBS=$(( $(sysctl -n hw.ncpu) - 1 ))

.github/workflows/test-backend-arm.yml

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@ on:
    paths:
      - .github/workflows/test-backend-arm.yml
      - .github/workflows/_test_backend.yml
+      - .ci/scripts/test_backend.sh
+      - backends/test/suite/flow.py
+      - backends/test/suite/flows/arm.py
  workflow_dispatch:

concurrency:

CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -1021,6 +1021,10 @@ if(EXECUTORCH_BUILD_EXECUTOR_RUNNER)
    extension_runner_util gflags executorch_backends
  )

+  if(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR)
+    list(APPEND _executor_runner_libs extension_flat_tensor)
+  endif()
+
  if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED)
    list(APPEND _executor_runner_libs optimized_native_cpu_ops_lib)
  elseif(EXECUTORCH_BUILD_CADENCE)

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@
from .insert_int32_casts_after_int64_placeholders import (  # noqa
    InsertInt32CastsAfterInt64PlaceholdersPass,
)
-from .insert_rescales_pass import InsertRescalePass  # noqa
+from .insert_rescales_pass import InsertRescaleInt32Pass, InsertRescalePass  # noqa
from .insert_table_ops import InsertTableOpsPass  # noqa
from .match_arg_dtype_pass import MatchArgDtypePass  # noqa
from .match_arg_ranks_pass import MatchArgRanksPass  # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -81,6 +81,7 @@
    FuseEqualPlaceholdersPass,
    FuseQuantizedActivationPass,
    InsertInt32CastsAfterInt64PlaceholdersPass,
+    InsertRescaleInt32Pass,
    InsertRescalePass,
    InsertTableOpsPass,
    MatchArgDtypePass,
@@ -214,6 +215,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
        self.add_pass(ToTosaMemoryFormatPass(exported_program))
        self.add_pass(RemoveNoopPass())
        self.add_pass(InsertRescalePass())
+        self.add_pass(InsertRescaleInt32Pass())

        self.validate_constraints_mandatory()
        return self._transform(exported_program.graph_module)

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 238 additions & 2 deletions
@@ -4,9 +4,14 @@
# LICENSE file in the root directory of this source tree.

from copy import copy
-from typing import cast, Set, Type
+from typing import cast, Dict, Optional, Set, Tuple, Type

-from executorch.backends.arm._passes.arm_pass_utils import create_node
+import torch
+from executorch.backends.arm._passes.arm_pass import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import create_node, set_node_arg
+from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
+    get_output_qparams,
+)
from executorch.backends.arm._passes.quant_args import QuantArgs
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
from executorch.exir.dialects._ops import ops as exir_ops
@@ -65,3 +70,234 @@ def call(self, graph_module: GraphModule) -> PassResult:
        graph_module = super().call(graph_module).graph_module
        graph_module.recompile()
        return PassResult(graph_module, modified)
+
+
+class InsertRescaleInt32Pass(ArmPass):
+    """
+    Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
+    quantized implementations. This pass treats such operator nodes by
+    inserting rescale ops before and after them if needed. Note that extra logic
+    that handles the scales and zero points must be in place because the affected
+    TOSA ops have naive implementations that do not account for the quantization
+    parameters.
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
+    included_targets = [
+        exir_ops.edge.aten.abs.default,
+        exir_ops.edge.aten.eq.Tensor,
+        exir_ops.edge.aten.ge.Tensor,
+        exir_ops.edge.aten.gt.Tensor,
+        exir_ops.edge.aten.le.Tensor,
+        exir_ops.edge.aten.lt.Tensor,
+        exir_ops.edge.aten.maximum.default,
+        exir_ops.edge.aten.minimum.default,
+    ]
+
+    def _int32_qargs(self, s):
+        """Helper creator function for INT32-based QuantArgs"""
+
+        return QuantArgs(
+            scale=s,
+            zp=0,
+            qmin=torch.iinfo(torch.int32).min,
+            qmax=torch.iinfo(torch.int32).max,
+            dtype=torch.int32,
+        )
+
+    def _get_inputs_rescaled_qparams(
+        self, target, input_qparams: Dict[int, QuantArgs]
+    ) -> Dict[int, QuantArgs]:
+        """Get the qparams for the INT32 operands to the op ``target``.
+
+        Inputs to the INT32-based operator must be rescaled from INT8 to INT32.
+        This function computes the ``QuantArgs`` for each of the operands and returns
+        it as a dict, mapping tensor index to ``QuantArgs``.
+        """
+
+        if target in [
+            exir_ops.edge.aten.abs.default,
+            exir_ops.edge.aten.eq.Tensor,
+            exir_ops.edge.aten.ge.Tensor,
+            exir_ops.edge.aten.gt.Tensor,
+            exir_ops.edge.aten.le.Tensor,
+            exir_ops.edge.aten.lt.Tensor,
+            exir_ops.edge.aten.minimum.default,
+            exir_ops.edge.aten.maximum.default,
+        ]:
+            # For these ops, use the smallest scale among the INT8 operands.
+            min_scale = min(
+                [qp.get_scale_per_tensor() for qp in input_qparams.values()]
+            )
+            qparams = {
+                i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
+            }
+        else:
+            raise ValueError(f"Not a valid target: {target}")
+
+        return qparams
+
+    def _get_output_qparams(
+        self, target, inputs_qparams: Dict[int, QuantArgs]
+    ) -> Optional[QuantArgs]:
+        """Given an op ``target`` and the ``QuantArgs`` for each of its inputs, compute
+        the scale of the output based on how the operator itself affects it."""
+
+        if target in [
+            exir_ops.edge.aten.abs.default,
+            exir_ops.edge.aten.maximum.default,
+            exir_ops.edge.aten.minimum.default,
+        ]:
+            # The op has not altered the scale; the output scale is equal to
+            # the operands' scales.
+            return self._int32_qargs(inputs_qparams[0].get_scale_per_tensor())
+        elif target in [
+            exir_ops.edge.aten.eq.Tensor,
+            exir_ops.edge.aten.ge.Tensor,
+            exir_ops.edge.aten.gt.Tensor,
+            exir_ops.edge.aten.le.Tensor,
+            exir_ops.edge.aten.lt.Tensor,
+        ]:
+            # Output is bool for these ops and thus no qparams are present.
+            return None
+        else:
+            raise ValueError(f"Not a valid target: {target}")
+
+    def _get_rescale_qparams(
+        self, target, input_qparams: Dict[int, QuantArgs]
+    ) -> Tuple[Dict[int, QuantArgs], Optional[QuantArgs]]:
+        """
+        Get the quantization parameters of the INT32 inputs/outputs that will
+        surround the node after the new RESCALE ops have been inserted.
+        """
+
+        inputs_rescaled_qparams = self._get_inputs_rescaled_qparams(
+            target, input_qparams
+        )
+        output_qparams = self._get_output_qparams(target, inputs_rescaled_qparams)
+
+        return (inputs_rescaled_qparams, output_qparams)
+
+    def _rescale_inputs(self, graph, node, rescale_qargs: Dict[int, QuantArgs]) -> bool:
+        qargs = node.meta["input_qparams"]
+
+        args_copy = list(node.args)
+        seen_args = set()
+        modified = False
+        for i in qargs:
+            qp = qargs[i]
+            if qp.dtype != torch.int8:
+                continue
+
+            arg_node = args_copy[i]
+            if arg_node in seen_args:
+                continue
+            seen_args.add(arg_node)
+
+            with graph.inserting_after(arg_node):
+                rescale_node = create_node(
+                    graph,
+                    exir_ops.backend.tosa.RESCALE.default,
+                    (
+                        arg_node,
+                        torch.int32,
+                        qp.get_scale_per_tensor()
+                        / rescale_qargs[i].get_scale_per_tensor(),  # Old scale / new scale
+                        qp.get_zp_per_tensor(),  # Old zero point
+                        rescale_qargs[i].get_zp_per_tensor(),  # New zero point
+                    ),
+                    from_node=node,
+                )
+
+            node.replace_input_with(arg_node, rescale_node)
+            modified = True
+
+        return modified
+
+    def _rescale_outputs(self, graph, node, rescale_qargs: Optional[QuantArgs]) -> bool:
+        if "output_qparams" not in node.meta or len(node.meta["output_qparams"]) == 0:
+            return False
+
+        qargs = get_output_qparams(node)
+        assert len(qargs) == 1
+        assert rescale_qargs is not None
+
+        qarg = qargs[0]
+        if qarg.dtype != torch.int8:
+            return False
+
+        users_copy = list(node.users)
+
+        with graph.inserting_after(node):
+            rescale_node = create_node(
+                graph,
+                exir_ops.backend.tosa.RESCALE.default,
+                (
+                    node,
+                    torch.int8,
+                    rescale_qargs.get_scale_per_tensor()
+                    / qarg.get_scale_per_tensor(),  # Old scale / new scale
+                    rescale_qargs.get_zp_per_tensor(),  # Old zero point
+                    qarg.get_zp_per_tensor(),  # New zero point
+                ),
+                from_node=node,
+            )
+
+        for user in users_copy:
+            user.replace_input_with(node, rescale_node)
+
+        return True
+
+    def call(self, graph_module: GraphModule) -> PassResult:
+        graph = graph_module.graph
+
+        modified = False
+        for node in list(graph.nodes):
+            node = cast(Node, node)
+
+            if node.op != "call_function" or node.target not in self.included_targets:
+                continue
+
+            if "input_qparams" not in node.meta or len(node.meta["input_qparams"]) == 0:
+                continue
+            input_qparams = node.meta["input_qparams"]
+
+            inputs_rescale_qargs, output_rescale_qargs = self._get_rescale_qparams(
+                node.target, input_qparams
+            )
+
+            inputs_was_rescaled = self._rescale_inputs(
+                graph, node, inputs_rescale_qargs
+            )
+            outputs_was_rescaled = False
+            if inputs_was_rescaled:
+                outputs_was_rescaled = self._rescale_outputs(
+                    graph, node, output_rescale_qargs
+                )
+                modified = True
+
+            # Update node metadata
+
+            if inputs_was_rescaled:
+                assert len(inputs_rescale_qargs) == len(node.meta["input_qparams"])
+                node.meta["input_qparams"] = inputs_rescale_qargs
+
+            if outputs_was_rescaled:
+                assert len(node.meta["output_qparams"]) == 1
+                node.meta["output_qparams"] = {0: output_rescale_qargs}
+
+                # If the output type is specified in the node, change it such
+                # that it matches the subsequent rescale node(s) that this node
+                # now has output edges to.
+                if "dtype" in node.kwargs:
+                    set_node_arg(node, "dtype", torch.int32)
+
+        if modified:
+            # Retrace the graph to update the fake tensor types.
+            graph_module = super().call(graph_module).graph_module
+            graph_module.recompile()
+
+        return PassResult(graph_module, modified)
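
Editor's aside: the sketch below illustrates the scale arithmetic that InsertRescaleInt32Pass encodes, namely that each inserted RESCALE carries "old scale / new scale" plus the old and new zero points, and that the shared INT32 scale is the smallest INT8 operand scale. It is a standalone Python illustration with invented example values; the names (int8_scales, input_multipliers, the round-trip check, and so on) are not part of this commit.

# Illustrative sketch only: mirrors the scale/zero-point formulas used by
# InsertRescaleInt32Pass; all concrete values below are made up.

# Per-tensor INT8 qparams of the two operands of e.g. aten.minimum.
int8_scales = {0: 0.25, 1: 0.0625}
int8_zps = {0: -3, 1: 7}

# The pass gives every INT32 operand the smallest operand scale and zero point 0.
int32_scale = min(int8_scales.values())  # 0.0625
int32_zp = 0

# Each input RESCALE (INT8 -> INT32) carries old_scale / new_scale
# plus the old and new zero points.
input_multipliers = {i: s / int32_scale for i, s in int8_scales.items()}
print(input_multipliers)  # {0: 4.0, 1: 1.0}

# For abs/minimum/maximum the op keeps the scale, so the INT32 output also uses
# int32_scale; the output RESCALE (INT32 -> INT8) again uses old_scale / new_scale,
# now relative to the original INT8 output qparams. (Comparison ops return bool,
# so only their inputs are rescaled.)
int8_out_scale, int8_out_zp = 0.25, -3
output_multiplier = int32_scale / int8_out_scale  # 0.25

# Round trip for one quantized value of operand 0: widen to INT32, narrow back.
q8 = -1  # represents 0.25 * (-1 - (-3)) = 0.5
q32 = round((q8 - int8_zps[0]) * input_multipliers[0]) + int32_zp  # 8
q8_back = round((q32 - int32_zp) * output_multiplier) + int8_out_zp  # -1
assert q8_back == q8

Because the multiplier is old scale over new scale and zero points are remapped, the INT32 intermediates represent the same real values as the INT8 operands (up to rounding), which is what lets the naive integer implementations of the listed ops operate on them directly.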
