Skip to content

Commit bfd3272

Browse files
authored
Merge branch 'main' into main
2 parents 5eef730 + 643c381 commit bfd3272

File tree

38 files changed

+809
-159
lines changed

38 files changed

+809
-159
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -18,6 +18,9 @@
1818
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
1919
ConvertExpandCopyToRepeatPass,
2020
)
21+
from executorch.backends.arm._passes.convert_full_like_to_full_pass import (
22+
ConvertFullLikeToFullPass,
23+
)
2124
from executorch.backends.arm._passes.convert_split_to_slice import (
2225
ConvertSplitToSlicePass,
2326
)
@@ -49,6 +52,7 @@
4952
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
5053
FuseQuantizedActivationPass,
5154
)
55+
from executorch.backends.arm._passes.insert_rescales_pass import InsertRescalePass
5256
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
5357
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
5458
KeepDimsFalseToSqueezePass,
@@ -72,6 +76,7 @@
7276
UnsqueezeScalarPlaceholdersPass,
7377
)
7478
from executorch.backends.arm.tosa_specification import TosaSpecification
79+
7580
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
7681
from executorch.exir import ExportedProgram
7782
from executorch.exir.pass_manager import PassManager
@@ -95,6 +100,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
95100
self.add_pass(ConvertMmToBmmPass())
96101
self.add_pass(DecomposeLinearPass())
97102
self.add_pass(ConvertMeanDimToAveragePoolPass())
103+
self.add_pass(ConvertFullLikeToFullPass())
98104

99105
self.add_pass(AnnotateDecomposedMatmulPass())
100106
self.add_pass(QuantizeOperatorArguments())
@@ -115,7 +121,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
115121
self.add_pass(ConvertSqueezesToViewPass())
116122

117123
self.add_pass(AnnotateChannelsLastDimOrder())
118-
124+
self.add_pass(InsertRescalePass())
119125
return self._transform(exported_program.graph_module)
120126

121127
def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
@@ -133,7 +139,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
133139
self.add_pass(ConvertMeanDimToAveragePoolPass())
134140
self.add_pass(DecomposeDivPass())
135141
self.add_pass(DecomposeSoftmaxesPass())
136-
142+
self.add_pass(ConvertFullLikeToFullPass())
137143
self.add_pass(AnnotateDecomposedMatmulPass())
138144
self.add_pass(QuantizeOperatorArguments())
139145
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
@@ -153,6 +159,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
153159
self.add_pass(ConvertSqueezesToViewPass())
154160

155161
self.add_pass(AnnotateChannelsLastDimOrder())
162+
self.add_pass(InsertRescalePass())
156163

157164
return self._transform(exported_program.graph_module)
158165

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from executorch.exir.dialects._ops import ops as exir_ops
7+
from executorch.exir.pass_base import ExportPass
8+
9+
10+
class ConvertFullLikeToFullPass(ExportPass):
    """Rewrite edge `full_like` calls as equivalent `full` calls.

    The full_like pytorch documentation states that
    `torch.full_like(input, fill_value)` is equivalent to
    `torch.full(input.size(),
                fill_value,
                dtype=input.dtype,
                layout=input.layout,
                device=input.device
                )`
    Layout and device are dropped because they are not relevant for this
    backend; only the shape and dtype of the input tensor are carried over.
    """

    def call_operator(self, op, args, kwargs, meta):
        is_full_like = op in (exir_ops.edge.aten.full_like.default,)
        if is_full_like:
            # args[0] is the tensor to mimic, args[1] is the fill value.
            reference = args[0].data
            return super().call_operator(
                exir_ops.edge.aten.full.default,
                (list(reference.shape), args[1]),
                {"dtype": reference.dtype},
                meta,
            )

        # Every other operator is forwarded untouched.
        return super().call_operator(op, args, kwargs, meta)

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
131131
n = cast(Node, n)
132132
if n.op != "call_function":
133133
continue
134+
# Don't fold chains of quant-ops into each other.
135+
if n.target in (q_op, dq_op):
136+
continue
134137

135138
# Make sure we haven't already set qparams meta information on the node
136139
assert "input_qparams" not in n.meta.keys()
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import logging
7+
from copy import copy
8+
from typing import cast
9+
10+
import torch
11+
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, QuantArgs
13+
from executorch.exir.pass_base import ExportPass, PassResult
14+
from torch import Tensor
15+
from torch.fx import GraphModule, Node
16+
from torch.library import custom_op, register_fake
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
@custom_op("tosa::_rescale", mutates_args=())  # type: ignore[misc]
def rescale(
    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
) -> Tensor:
    """Placeholder eager implementation of the `tosa::_rescale` custom op.

    Only casts `x` to `dtype`; `scale`, `in_zp` and `out_zp` are not applied,
    so the result is not a numerically correct rescale. A warning is logged
    every time this implementation runs.

    Args:
        x: Tensor to rescale.
        dtype: Dtype of the returned tensor.
        scale: Rescale factor (unused here).
        in_zp: Input zero point (unused here).
        out_zp: Output zero point (unused here).

    Returns:
        A new tensor holding `x` cast to `dtype`.
    """
    # The trailing space keeps the two sentences apart once the adjacent
    # string literals are concatenated.
    logger.warning(
        "Ran default implementation of tosa::_rescale. "
        "This op is meant to always be inserted inside a partition and a correct default implementation is not implemented."
    )
    # Clone is needed to not return reference when rescaling to same dtype.
    # This is a necessary requirement for non-mutating custom ops.
    return x.to(dtype=dtype).clone()
32+
33+
34+
@register_fake("tosa::_rescale")  # type: ignore[misc]
def rescale_fake(
    x: Tensor, dtype: torch.dtype, scale: float, in_zp: int, out_zp: int
) -> Tensor:
    """Fake implementation of `tosa::_rescale`.

    Validates the TOSA constraints of a RESCALE op and then casts `x` to
    `dtype` so that the resulting tensor meta is correct.

    Raises:
        NotImplementedError: If `dtype` is neither int32 nor int8.
        ValueError: If a zero point is non-zero for an int32 input/output,
            or lies outside the int8 range for an int8 input/output.
    """
    to_int32 = dtype == torch.int32
    to_int8 = dtype == torch.int8
    if not to_int32 and not to_int8:
        raise NotImplementedError(
            "tosa::rescale currently only supports int32 and int8."
        )

    # int32 tensors must use a zero point of exactly zero on either side.
    if to_int32 and out_zp != 0:
        raise ValueError(
            "TOSA requires output_zp to be zero when the output dtype is int32."
        )
    if x.dtype == torch.int32 and in_zp != 0:
        raise ValueError(
            "TOSA requires input_zp to be zero when the input dtype is int32."
        )

    # int8 zero points must be representable as an int8 value.
    if x.dtype == torch.int8:
        if not (-128 <= in_zp <= 127):
            raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")
    if to_int8:
        if not (-128 <= out_zp <= 127):
            raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.")

    converted = x.to(dtype=dtype)
    return converted.clone()
59+
60+
61+
class InsertRescalePass(ExportPass):
    """Replaces each dq -> q node pair with a single tosa::_rescale node.

    The pass does not guarantee that the resulting dtypes and zero points
    are valid in TOSA; that is the job of the quantization annotator that
    produced the dq and q nodes. The TOSA constraints are validated in the
    fake implementation of tosa::_rescale.
    """

    def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule):
        """Replace the q node `user`, fed by the dq node `node`, with a rescale.

        The rescale reads the dq node's input directly, takes the q node's
        dtype, and uses the ratio dq scale / q scale as its scale.
        """
        graph = graph_module.graph
        dequant_params = QuantArgs.from_operator(node.target, node.args)
        quant_params = QuantArgs.from_operator(user.target, user.args)

        rescale_args = (
            node.all_input_nodes[0],
            quant_params.dtype,
            dequant_params.scale / quant_params.scale,
            dequant_params.zp,
            quant_params.zp,
        )

        with graph.inserting_before(node):
            rescale_node = create_node(
                graph, torch.ops.tosa._rescale.default, rescale_args
            )
            # The rescale takes over the q node's meta and all of its uses.
            rescale_node.meta = copy(user.meta)
            user.replace_all_uses_with(rescale_node)
            graph.erase_node(user)

    def call(self, graph_module: GraphModule) -> PassResult:
        """Fold every dq -> q pair in `graph_module` and report if any was found."""
        changed = False
        for candidate in graph_module.graph.nodes:
            candidate = cast(Node, candidate)
            if candidate.target is dq_op:
                # Iterate over a snapshot: folding erases users, which mutates
                # the live candidate.users mapping.
                for consumer in copy(candidate.users):
                    if consumer.target is not q_op:
                        continue
                    self.fold_dq_q_to_rescale(candidate, consumer, graph_module)
                    changed = True
                    # Drop the dq node once its last user has been folded away.
                    if not candidate.users:
                        graph_module.graph.erase_node(candidate)

        graph_module = super().call(graph_module).graph_module
        graph_module.recompile()
        return PassResult(graph_module, changed)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
105105
exir_ops.edge.aten.linear.default,
106106
exir_ops.edge.aten.split_with_sizes_copy.default,
107107
exir_ops.edge.aten.full.default,
108+
exir_ops.edge.aten.full_like.default,
108109
exir_ops.edge.aten.ge.Tensor,
109110
exir_ops.edge.aten.gt.Tensor,
110111
exir_ops.edge.aten.le.Tensor,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
op_reciprocal,
3333
op_relu,
3434
op_repeat,
35+
op_rescale,
3536
op_rshift,
3637
op_rsqrt,
3738
op_sigmoid,
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
from typing import cast, List
9+
10+
import executorch.backends.arm.tosa_quant_utils as tosa_quant_utils
11+
import serializer.tosa_serializer as ts # type: ignore
12+
import torch
13+
14+
import tosa.Op as TosaOp # type: ignore
15+
from executorch.backends.arm.operators.node_visitor import (
16+
NodeVisitor,
17+
register_node_visitor,
18+
)
19+
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
20+
from torch.fx import Node
21+
22+
23+
@register_node_visitor
class RescaleVisitor(NodeVisitor):
    """Serializes a `_rescale.default` node as a TOSA RESCALE operator."""

    target = "_rescale.default"

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add a per-tensor, signed RESCALE operator for `node` to `tosa_graph`.

        The output dtype, scale, input zero point and output zero point are
        read from `node.args[1]` through `node.args[4]`.

        Raises:
            ValueError: If a zero point is non-zero while the corresponding
                dtype is not int8.
        """
        input_dtype = inputs[0].dtype
        output_dtype = cast(torch.dtype, node.args[1])
        scale = cast(float, node.args[2])
        input_zp = cast(int, node.args[3])
        output_zp = cast(int, node.args[4])

        # Skip int16 cases for now.
        if input_dtype != map_dtype(torch.int8) and input_zp != 0:
            raise ValueError(
                f"If input dtype is not int8, input_zp must be 0. Got input_dtype={ts.DTypeNames[input_dtype]}, {input_zp=}"
            )
        if output_dtype != torch.int8 and output_zp != 0:
            raise ValueError(
                f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
            )

        # int32 outputs use a 32-bit scale; every other output uses 16 bits.
        scale32 = output_dtype == torch.int32
        scale_width = 32 if scale32 else 16
        multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
            scale, scale_width
        )

        attr_rescale = ts.TosaSerializerAttribute()
        attr_rescale.RescaleAttribute(
            input_zp=input_zp,
            output_zp=output_zp,
            multiplier=[multiplier],
            shift=[shift],
            scale32=scale32,
            double_round=False,
            per_channel=False,
            input_unsigned=False,
            output_unsigned=False,
        )

        tosa_graph.addOperator(
            TosaOp.Op().RESCALE, [inputs[0].name], [output.name], attr_rescale
        )

backends/arm/quantizer/quantization_annotator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def _match_pattern(
134134
torch.ops.aten.sum.dim_IntList,
135135
torch.ops.aten.hardsigmoid.default,
136136
torch.ops.aten.hardswish.default,
137+
torch.ops.aten.full_like.default,
137138
]
138139

139140
_one_to_one_shared_input_qspec = [
@@ -383,3 +384,11 @@ def annotate_graph( # type: ignore[return]
383384
_annotate_output(node, quant_properties.quant_output)
384385

385386
arm_quantizer_utils.mark_node_as_annotated(node) # type: ignore[attr-defined]
387+
388+
# Quantization does not allow kwargs for some reason.
389+
# Remove kwargs from the ops we know carry them and where dropping them does not break anything.
390+
if node.target in [
391+
torch.ops.aten.full_like.default,
392+
torch.ops.aten.full.default,
393+
]:
394+
node.kwargs = {}

0 commit comments

Comments
 (0)