Skip to content

Commit 141d18c

Browse files
authored
Merge branch 'main' into minimax_docs
2 parents 2a43b48 + 11ff521 commit 141d18c

File tree

113 files changed

+1551
-812
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

113 files changed

+1551
-812
lines changed

backends/arm/_passes/decompose_meandim_pass.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
1414
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
1515
from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
16+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1617
from executorch.exir.backend.utils import WhyNoPartitionReporter
1718
from executorch.exir.dialects._ops import ops as exir_ops
1819
from executorch.exir.pass_base import ExportPass
@@ -50,6 +51,15 @@ def get_view(op):
5051
raise RuntimeError(f"Can't get meandim decomposition for op {op}")
5152

5253

54+
def get_quantization(op):
    """Look up the quantize op that pairs with a dequantize op.

    Args:
        op: Operator target to inspect.

    Returns:
        A ``(quant_op, dequant_op)`` tuple when ``op`` is found in ``DQ_OPS``,
        where ``quant_op`` is the entry of ``Q_OPS`` at the same position and
        ``dequant_op`` is ``op`` itself; ``None`` when ``op`` is not in
        ``DQ_OPS``.
    """
    if op not in DQ_OPS:
        return None
    # The input of a dequant node can be a placeholder, so the quant op cannot
    # be read from the graph; rely on Q_OPS and DQ_OPS sharing positions.
    matching_q_op = Q_OPS[DQ_OPS.index(op)]
    return matching_q_op, op
61+
62+
5363
class DecomposeMeanDimPass(ArmPass):
5464
"""
5565
Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
@@ -121,6 +131,7 @@ def call_operator(self, op, args, kwargs, meta):
121131
dims_to_reduce = [dim - 1 for dim in dims_to_reduce]
122132

123133
x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
134+
x = self._maybe_insert_q_dq_after(x, meta)
124135

125136
# Reduce (h,w) dims by avg pool if possible
126137
x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
@@ -133,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta):
133144
dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
134145

135146
x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
136-
147+
x = self._maybe_insert_q_dq_after(x, meta)
137148
# Reduce remaining dims by sum
138149
x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
139150

@@ -156,6 +167,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
156167
full = super().call_operator(
157168
full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
158169
)
170+
if (quant_ops := get_quantization(input_node.node.target)) is not None:
171+
# Insert Q and DQ nodes after full op.
172+
# Since the value of full is known, we can compute quant params such that dq(q_max_value) = 1/N.
173+
q_op, dq_op = quant_ops
174+
qmax = input_node.node.args[4]
175+
full_quant_args = (
176+
1 / (N * qmax), # Scale to map qmax to 1/N
177+
0, # Zero point
178+
*input_node.node.args[3:],
179+
)
180+
q_args = (full, *full_quant_args)
181+
full = super().call_operator(
182+
q_op,
183+
q_args,
184+
kwargs={},
185+
meta=meta,
186+
updated=True,
187+
)
188+
dq_args = (full, *full_quant_args)
189+
full = super().call_operator(
190+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
191+
)
192+
193+
# Insert Q and DQ nodes after sum op.
194+
# Scale needs to be adjusted with N, since it was computed on data after the division with N.
195+
sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:])
196+
q_args = (sum, *sum_quant_args)
197+
sum = super().call_operator(
198+
q_op,
199+
q_args,
200+
kwargs={},
201+
meta=meta,
202+
updated=True,
203+
)
204+
dq_args = (sum, *sum_quant_args)
205+
sum = super().call_operator(
206+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
207+
)
208+
159209
return super().call_operator(mul_op, (sum, full), {}, meta, True)
160210

161211
def _reduce_by_average_pool(self, op, input_node, dims, meta):
@@ -190,10 +240,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
190240
)
191241

192242
if is_supported:
243+
out = super().call_operator(avgpool_op, args, {}, meta, True)
244+
out = self._maybe_insert_q_dq_after(out, meta)
193245
return (
194-
super().call_operator(avgpool_op, args, {}, meta, True),
246+
out,
195247
dims_to_reduce_by_sum,
196248
)
197249

198250
else:
199251
return input_node, dims
252+
253+
def _maybe_insert_q_dq_after(self, op, meta):
254+
"""If the input node of op is a dequant node, insert a q-dq pair after op with identical quantization parameters."""
255+
256+
if len(op.node.all_input_nodes) > 1:
257+
raise ValueError(
258+
f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
259+
)
260+
input_node = op.node.all_input_nodes[0]
261+
if (quant_ops := get_quantization(input_node.target)) is not None:
262+
q_op, dq_op = quant_ops
263+
quant_args = list(input_node.args[1:])
264+
q_args = (op, *quant_args)
265+
out = super().call_operator(
266+
q_op,
267+
q_args,
268+
kwargs={},
269+
meta=meta,
270+
updated=True,
271+
)
272+
dq_args = (out, *quant_args)
273+
return super().call_operator(
274+
dq_op, dq_args, kwargs={}, meta=meta, updated=True
275+
)
276+
else:
277+
return op

backends/arm/_passes/fuse_constant_ops_pass.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def resolve_arg(arg):
6565
if isinstance(arg, torch.fx.Node) and arg in input_nodes:
6666
idx = input_nodes.index(arg)
6767
t = get_param_tensor(self.exported_program, arg)
68-
if qparams:
68+
# Check if qparams exist for this arg
69+
if qparams and idx in qparams.keys():
6970
t = qparams[idx].dequantize_value(t)
7071
return t
7172
if isinstance(arg, tuple):

backends/arm/quantizer/quantization_annotator.py

Lines changed: 126 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5+
"""Provide quantization annotation logic for Arm backends.
6+
7+
This module computes per-node quantization properties and applies input/output
8+
annotations to FX graphs using TorchAO qspecs.
9+
10+
"""
511

612
import logging
713
import operator
@@ -44,12 +50,31 @@ class _QuantProperty:
4450

4551

4652
class _OpQuantProperties:
53+
"""Collect input/output quantization properties for a node.
54+
55+
Attributes:
56+
quant_inputs (List[_QuantProperty]): Quantization specs for inputs
57+
indexed by argument positions.
58+
quant_output (Optional[_QuantProperty]): Quantization spec for the
59+
node's output when applicable.
60+
61+
"""
62+
4763
def __init__(self):
4864
self.quant_inputs: List[_QuantProperty] = []
4965
self.quant_output: Optional[_QuantProperty] = None
5066

5167

5268
def _as_list(x):
69+
"""Return ``x`` wrapped as a list if needed.
70+
71+
Args:
72+
x: Value or list of values.
73+
74+
Returns:
75+
list: ``x`` if already a list; otherwise ``[x]``.
76+
77+
"""
5378
if isinstance(x, list):
5479
return x
5580
else:
@@ -66,16 +91,19 @@ def _is_ok_for_quantization(
6691
A node can be quantized if:
6792
- All inputs that are required for quantization are of type `float32`
6893
and are not large scalar values.
69-
- The output of the node itself is of type `float32` and is not a large scalar.
94+
- The output of the node itself is of type `float32` and is not a large
95+
scalar.
7096
7197
Args:
7298
node (Node): The node being analyzed.
73-
quant_properties (_OpQuantProperties): Contains quantization properties for
74-
the node, including input and output quantization specifications.
75-
gm (torch.fx.GraphModule): The graph module containing the computational graph.
99+
quant_properties (_OpQuantProperties): Contains quantization properties
100+
for the node, including input and output quantization specifications.
101+
gm (torch.fx.GraphModule): The graph module containing the computational
102+
graph.
76103
77104
Returns:
78105
bool: `True` if the node can be quantized, otherwise `False`.
106+
79107
"""
80108
# Check output
81109
if quant_properties.quant_output is not None:
@@ -127,16 +155,28 @@ def _is_ok_for_quantization(
127155

128156

129157
def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str: str):
158+
"""Get an attribute from a module by dotted path.
159+
160+
Args:
161+
module (torch.nn.Module | torch.fx.GraphModule): Root module.
162+
target_str (str): Dotted attribute path, e.g., ``"sub.weight"``.
163+
164+
Returns:
165+
Any: Resolved attribute on the module.
166+
167+
"""
130168
targets = target_str.split(".")
131169
for target in targets[:-1]:
132170
module = module.get_submodule(target)
133171
return getattr(module, targets[-1])
134172

135173

136174
def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
137-
"""Check if input is a large scalar value. So that we can skip quantization for the
138-
node since histc op (in HistogramObserver) only works for values up to certain upper
139-
bound.
175+
"""Return True if input is a large scalar value.
176+
177+
Large scalars are skipped because ``torch.histc`` supports values only up
178+
to a certain upper bound.
179+
140180
"""
141181
HISTC_UPPER_BOUND = 3.4028235e15
142182
if node.op == "get_attr" and isinstance(node.target, str):
@@ -166,11 +206,12 @@ def _is_non_float_tensor(node: Node) -> bool:
166206
bool: `True` if the data type is not float32, otherwise `False`.
167207
168208
Note:
169-
- If `node.meta["val"]` is a `list`, the function returns `True` if **any**
170-
element is **not** an instance of `FakeTensor` or does **not** have
209+
- If `node.meta["val"]` is a `list`, the function returns `True` if
210+
any element is not an instance of `FakeTensor` or does not have
171211
`torch.float32` as its data type.
172-
- If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
173-
function returns True.
212+
- If node.meta["val"] is missing or is not an instance of `FakeTensor`,
213+
the function returns True.
214+
174215
"""
175216
if "val" in node.meta and isinstance(node.meta["val"], Sequence):
176217
return any(
@@ -186,6 +227,20 @@ def _is_non_float_tensor(node: Node) -> bool:
186227

187228

188229
def _annotate_input(node: Node, quant_property: _QuantProperty):
230+
"""Annotate a node's input with the given qspec.
231+
232+
Maps the specified input argument(s) to the provided quantization spec and
233+
optionally marks the input node(s) as annotated.
234+
235+
Args:
236+
node (Node): Node whose input should be annotated.
237+
quant_property (_QuantProperty): Input index and qspec(s).
238+
239+
Raises:
240+
RuntimeError: If the node is already annotated.
241+
TypeError: If an input argument is not a ``Node`` instance.
242+
243+
"""
189244
if is_annotated(node):
190245
raise RuntimeError(
191246
f"Cannot annotate input: node '{node.name}' is already annotated"
@@ -211,6 +266,18 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
211266

212267

213268
def _annotate_output(node: Node, quant_property: _QuantProperty):
269+
"""Annotate a node's output with the given qspec.
270+
271+
Args:
272+
node (Node): Node whose output should be annotated.
273+
quant_property (_QuantProperty): Output index and qspec.
274+
275+
Raises:
276+
RuntimeError: If the node is already annotated.
277+
ValueError: If ``mark_annotated`` is True, ``optional`` is True, or
278+
``index`` is not zero.
279+
280+
"""
214281
if is_annotated(node):
215282
raise RuntimeError(
216283
f"Cannot annotate output: node '{node.name}' is already annotated"
@@ -230,12 +297,13 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
230297
def _match_pattern(
231298
node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
232299
) -> bool:
233-
"""
234-
Check if there's a chain of node.ancestors? -> node -> node.descendant? that matches the
235-
chain provided in 'pattern'. If 'filter_fn' is provided, check that all the nodes in the
236-
chain pass the filtering.
300+
"""Check whether a node chain matches a pattern.
301+
302+
Verify a chain of ancestors -> node -> descendants matches the provided
303+
``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
304+
to pass the filter. Each pattern element is a list of disjunctive node
305+
targets.
237306
238-
Each 'pattern' element is composed of a list of disjunctive nodes types.
239307
"""
240308
if len(pattern) < 1:
241309
raise ValueError("No pattern provided")
@@ -382,6 +450,21 @@ def _match_pattern(
382450
def get_quant_properties( # noqa: C901
383451
node: Node, gm: torch.fx.GraphModule, quantization_config
384452
) -> _OpQuantProperties | None:
453+
"""Compute quantization properties for a node.
454+
455+
Determine which inputs and/or outputs should be annotated for quantization
456+
based on the node's operator and surrounding pattern.
457+
458+
Args:
459+
node (Node): Node to analyze.
460+
gm (torch.fx.GraphModule): Owning graph module.
461+
quantization_config: Source for activation/weight/bias qspecs.
462+
463+
Returns:
464+
_OpQuantProperties | None: Properties to apply, or ``None`` if the
465+
node is unsupported or not suitable for quantization.
466+
467+
"""
385468
input_act_qspec = quantization_config.get_input_act_qspec()
386469
weight_qspec = quantization_config.get_weight_qspec()
387470
output_act_qspec = quantization_config.get_output_act_qspec()
@@ -390,6 +473,7 @@ def get_quant_properties( # noqa: C901
390473
quant_properties = _OpQuantProperties()
391474

392475
def any_or_hardtanh_min_zero(n: Node):
    """Accept every non-hardtanh node, and hardtanh only when ``min_val == 0``."""
    # Anything that is not a hardtanh passes unconditionally.
    if n.target != torch.ops.aten.hardtanh.default:
        return True
    # For hardtanh, the second positional argument must be zero.
    return n.args[1] == 0
395479

@@ -524,12 +608,19 @@ def any_or_hardtanh_min_zero(n: Node):
524608
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
525609
elif node.target in (torch.ops.aten.where.self,):
526610
true_node = ensure_type(Node, node.args[1])
527-
shared_qspec = SharedQuantizationSpec(true_node)
611+
input_qspec = (
612+
SharedQuantizationSpec(true_node)
613+
if is_output_annotated(true_node)
614+
else input_act_qspec
615+
)
528616
quant_properties.quant_inputs = [
529-
_QuantProperty(1, shared_qspec),
530-
_QuantProperty(2, shared_qspec),
617+
_QuantProperty(1, input_qspec),
618+
_QuantProperty(2, SharedQuantizationSpec((true_node, node))),
531619
]
532-
quant_properties.quant_output = _QuantProperty(0, shared_qspec)
620+
quant_properties.quant_output = _QuantProperty(
621+
0,
622+
SharedQuantizationSpec((true_node, node)),
623+
)
533624
elif node.target in _one_to_one_shared_input_or_input_act_qspec:
534625
input_node = ensure_type(Node, node.args[0])
535626
input_qspec = (
@@ -636,6 +727,21 @@ def annotate_graph( # type: ignore[return]
636727
quantization_config: QuantizationConfig,
637728
filter_fn: Optional[Callable[[Node], bool]] = None,
638729
) -> Optional[List[List[Node]]]:
730+
"""Annotate supported nodes in a graph with quantization specs.
731+
732+
Iterate through call_function nodes, compute quantization properties, and
733+
apply input/output annotations. A filter can restrict which nodes are
734+
considered.
735+
736+
Args:
737+
gm (torch.fx.GraphModule): Graph to annotate.
738+
quantization_config (QuantizationConfig): Default qspecs for nodes.
739+
filter_fn (Optional[Callable[[Node], bool]]): Optional node predicate.
740+
741+
Returns:
742+
Optional[List[List[Node]]]: Reserved for future use; currently None.
743+
744+
"""
639745
for node in gm.graph.nodes:
640746
if node.op != "call_function":
641747
continue

0 commit comments

Comments
 (0)