
Commit a89b963

Arm backend: Move quant util functions closer to usage (#13094)
The following functions are only used in quantization_annotator and can therefore be moved from arm_quantizer_utils.py to quantization_annotator.py:

* is_large_scalar
* is_non_float_tensor
* get_node_target

Additionally, is_ok_for_quantization is removed. It combined the is_large_scalar and is_non_float_tensor checks into one; that combination is now done directly where is_ok_for_quantization was used.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 325bbc0 commit a89b963
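
For illustration only, not part of the commit: a minimal, self-contained sketch of the call-site pattern the diffs below introduce. The combined is_ok_for_quantization helper is replaced by applying the two predicates directly, so the debug log can state which condition blocked quantization. The dict-based "node" and the stand-in predicates here are hypothetical; the real checks operate on torch.fx nodes.

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Toy stand-ins for the real _is_non_float_tensor / _is_large_scalar checks.
def _is_non_float_tensor(node) -> bool:
    return node.get("dtype") != "float32"

def _is_large_scalar(node, gm) -> bool:
    return node.get("scalar", 0) > 3.4028235e15

def can_quantize(node, gm) -> bool:
    # Old style: `if not is_ok_for_quantization(node, gm)` hid which condition failed.
    # New style: check each condition separately and log the specific reason.
    if _is_non_float_tensor(node):
        logger.debug("Could not quantize non float tensor for node: %s", node)
        return False
    elif _is_large_scalar(node, gm):
        logger.debug("Could not quantize large scalar node: %s", node)
        return False
    return True

print(can_quantize({"dtype": "float32"}, None))                  # True
print(can_quantize({"dtype": "int64"}, None))                    # False (non-float)
print(can_quantize({"dtype": "float32", "scalar": 1e16}, None))  # False (large scalar)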

File tree

2 files changed: +76 -70 lines changed

backends/arm/quantizer/arm_quantizer_utils.py
backends/arm/quantizer/quantization_annotator.py


backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 2 additions & 63 deletions

@@ -11,11 +11,9 @@
 # Utility functions for TOSAQuantizer
 #
 
-from typing import cast, Sequence
+from typing import cast
 
-import torch
-from torch._subclasses import FakeTensor
-from torch.fx import GraphModule, Node
+from torch.fx import Node
 
 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
 from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY
@@ -45,62 +43,3 @@ def mark_node_as_annotated(node: Node) -> None:
     if Q_ANNOTATION_KEY not in node.meta:
         node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
     node.meta[Q_ANNOTATION_KEY]._annotated = True
-
-
-def is_ok_for_quantization(node: Node, gm: GraphModule):
-    """Check if an node can be quantized. The node can not be quantized if:
-    - The node does not output a float tensor or,
-    - The node outputs a large scalar.
-    """
-    return not (is_non_float_tensor(node) or is_large_scalar(node, gm))
-
-
-def get_node_target(module: torch.nn.Module | GraphModule, target_str: str):
-    targets = target_str.split(".")
-    for target in targets[:-1]:
-        module = module.get_submodule(target)
-    return getattr(module, targets[-1])
-
-
-def is_large_scalar(node: Node, gm: GraphModule):
-    """Check if input is a large scalar value. So that we can skip quantization for the node
-    since histc op (in HistogramObserver) only works for values up to certain upper bound
-    """
-    if node.op == "get_attr" and isinstance(node.target, str):
-        tensor = get_node_target(gm, node.target)
-        # torch.histc works until this upper bound
-        HISTC_UPPER_BOUND = 3.4028235e15
-        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
-    return False
-
-
-def is_non_float_tensor(node: Node) -> bool:
-    """Check if the output of a node has a data type other than `torch.float32`.
-
-    If the output is not `torch.float32`, quantization cannot be performed, as
-    observers only work with floating-point tensors.
-
-    Args:
-        node (Node): The node to check the output(s) for.
-
-    Returns:
-        bool: `True` if the data type is not float32, otherwise `False`.
-
-    Note:
-        - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
-          element is **not** an instance of `FakeTensor` or does **not** have
-          `torch.float32` as its data type.
-        - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
-          function returns True.
-    """
-    if "val" in node.meta and isinstance(node.meta["val"], Sequence):
-        return any(
-            not isinstance(fake_tensor, FakeTensor)
-            or fake_tensor.dtype != torch.float32
-            for fake_tensor in node.meta["val"]
-        )
-
-    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
-        return True
-
-    return node.meta["val"].dtype != torch.float32
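
As an aside, not from the commit: a small runnable illustration of the "large scalar" criterion that is_large_scalar (and its new private counterpart) applies, using the HISTC_UPPER_BOUND value quoted above. The helper name exceeds_histc_bound is made up for this example.

import torch

HISTC_UPPER_BOUND = 3.4028235e15  # bound quoted in the removed is_large_scalar

def exceeds_histc_bound(tensor: torch.Tensor) -> bool:
    # Mirrors the tensor-side check: a single-element tensor whose magnitude
    # exceeds the bound is skipped from quantization.
    return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND

print(exceeds_histc_bound(torch.tensor(1.0)))        # False: ordinary scalar
print(exceeds_histc_bound(torch.tensor(1.0e16)))     # True: too large for histc
print(exceeds_histc_bound(torch.ones(4) * 1.0e16))   # False: not a scalar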

backends/arm/quantizer/quantization_annotator.py

Lines changed: 74 additions & 7 deletions

@@ -6,13 +6,14 @@
 import logging
 import operator
 from dataclasses import dataclass
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Sequence
 
 import torch
 import torch.fx
 import torch.nn.functional as F
 from executorch.backends.arm.quantizer import QuantizationConfig
 from executorch.backends.arm.tosa_utils import get_node_debug_info
+from torch._subclasses import FakeTensor
 
 from torch.fx import Node
 from torchao.quantization.pt2e.quantizer import (
@@ -24,7 +25,6 @@
 
 from .arm_quantizer_utils import (
     is_annotated,
-    is_ok_for_quantization,
     is_output_annotated,
     mark_node_as_annotated,
 )
@@ -78,9 +78,16 @@ def _is_ok_for_quantization(
     """
     # Check output
     if quant_properties.quant_output is not None:
-        if not is_ok_for_quantization(node, gm):  # type: ignore[attr-defined]
+        if _is_non_float_tensor(node):
             logger.debug(
-                f"Could not quantize node due to output: "
+                "Could not quantize non float tensor for the following output node: "
+                f"{get_node_debug_info(node, gm)}"
+            )
+
+            return False
+        elif _is_large_scalar(node, gm):
+            logger.debug(
+                "Could not quantize large scalar node for the following output node: "
                 f"{get_node_debug_info(node, gm)}"
             )
 
@@ -99,17 +106,77 @@ def _is_ok_for_quantization(
                 raise TypeError(
                     f"n_arg must be a Node instance, got {type(n_arg).__name__!r}"
                 )
-            if not is_ok_for_quantization(n_arg, gm):  # type: ignore[attr-defined]
+
+            if _is_non_float_tensor(n_arg):
                 logger.debug(
-                    f'could not quantize node due to input "{node}": '
-                    f"{get_node_debug_info(node, gm)}"
+                    "Could not quantize non float tensor for the following input "
+                    f"node: {get_node_debug_info(node, gm)}"
+                )
+
+                return False
+            elif _is_large_scalar(n_arg, gm):
+                logger.debug(
+                    "Could not quantize large scalar node for the following input "
+                    f"node: {get_node_debug_info(node, gm)}"
                 )
 
                 return False
 
     return True
 
 
+def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str: str):
+    targets = target_str.split(".")
+    for target in targets[:-1]:
+        module = module.get_submodule(target)
+    return getattr(module, targets[-1])
+
+
+def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
+    """Check if input is a large scalar value. So that we can skip quantization for the
+    node since histc op (in HistogramObserver) only works for values up to certain upper
+    bound.
+    """
+    if node.op == "get_attr" and isinstance(node.target, str):
+        tensor = _get_node_target(gm, node.target)
+        # torch.histc works until this upper bound
+        HISTC_UPPER_BOUND = 3.4028235e15
+        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
+    return False
+
+
+def _is_non_float_tensor(node: Node) -> bool:
+    """Check if the output of a node has a data type other than `torch.float32`.
+
+    If the output is not `torch.float32`, quantization cannot be performed, as
+    observers only work with floating-point tensors.
+
+    Args:
+        node (Node): The node to check the output(s) for.
+
+    Returns:
+        bool: `True` if the data type is not float32, otherwise `False`.
+
+    Note:
+        - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
+          element is **not** an instance of `FakeTensor` or does **not** have
+          `torch.float32` as its data type.
+        - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
+          function returns True.
+    """
+    if "val" in node.meta and isinstance(node.meta["val"], Sequence):
+        return any(
+            not isinstance(fake_tensor, FakeTensor)
+            or fake_tensor.dtype != torch.float32
+            for fake_tensor in node.meta["val"]
+        )
+
+    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
+        return True
+
+    return node.meta["val"].dtype != torch.float32
+
+
 def _annotate_input(node: Node, quant_property: _QuantProperty):
     if is_annotated(node):
         raise RuntimeError(
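
For context, not part of the diff: a self-contained sketch of how the dotted-path lookup in _get_node_target resolves a get_attr target, assuming standard torch.nn.Module.get_submodule behaviour. The Outer/Inner modules below are made up for the example.

import torch

class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.tensor(2.0))

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

def _get_node_target(module: torch.nn.Module, target_str: str):
    # Walk every submodule named in the dotted path except the last component...
    targets = target_str.split(".")
    for target in targets[:-1]:
        module = module.get_submodule(target)
    # ...then fetch the final attribute (a parameter, buffer, or submodule).
    return getattr(module, targets[-1])

print(_get_node_target(Outer(), "inner.scale"))  # the Parameter holding tensor(2.)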
