6
6
import logging
7
7
import operator
8
8
from dataclasses import dataclass
9
- from typing import Callable , List , Optional
9
+ from typing import Callable , List , Optional , Sequence
10
10
11
11
import torch
12
12
import torch .fx
13
13
import torch .nn .functional as F
14
14
from executorch .backends .arm .quantizer import QuantizationConfig
15
15
from executorch .backends .arm .tosa_utils import get_node_debug_info
16
+ from torch ._subclasses import FakeTensor
16
17
17
18
from torch .fx import Node
18
19
from torchao .quantization .pt2e .quantizer import (
24
25
25
26
from .arm_quantizer_utils import (
26
27
is_annotated ,
27
- is_ok_for_quantization ,
28
28
is_output_annotated ,
29
29
mark_node_as_annotated ,
30
30
)
@@ -78,9 +78,16 @@ def _is_ok_for_quantization(
78
78
"""
79
79
# Check output
80
80
if quant_properties .quant_output is not None :
81
- if not is_ok_for_quantization (node , gm ): # type: ignore[attr-defined]
81
+ if _is_non_float_tensor (node ):
82
82
logger .debug (
83
- f"Could not quantize node due to output: "
83
+ "Could not quantize non float tensor for the following output node: "
84
+ f"{ get_node_debug_info (node , gm )} "
85
+ )
86
+
87
+ return False
88
+ elif _is_large_scalar (node , gm ):
89
+ logger .debug (
90
+ "Could not quantize large scalar node for the following output node: "
84
91
f"{ get_node_debug_info (node , gm )} "
85
92
)
86
93
@@ -99,17 +106,77 @@ def _is_ok_for_quantization(
99
106
raise TypeError (
100
107
f"n_arg must be a Node instance, got { type (n_arg ).__name__ !r} "
101
108
)
102
- if not is_ok_for_quantization (n_arg , gm ): # type: ignore[attr-defined]
109
+
110
+ if _is_non_float_tensor (n_arg ):
103
111
logger .debug (
104
- f'could not quantize node due to input "{ node } ": '
105
- f"{ get_node_debug_info (node , gm )} "
112
+ "Could not quantize non float tensor for the following input "
113
+ f"node: { get_node_debug_info (node , gm )} "
114
+ )
115
+
116
+ return False
117
+ elif _is_large_scalar (n_arg , gm ):
118
+ logger .debug (
119
+ "Could not quantize large scalar node for the following input "
120
+ f"node: { get_node_debug_info (node , gm )} "
106
121
)
107
122
108
123
return False
109
124
110
125
return True
111
126
112
127
128
+ def _get_node_target (module : torch .nn .Module | torch .fx .GraphModule , target_str : str ):
129
+ targets = target_str .split ("." )
130
+ for target in targets [:- 1 ]:
131
+ module = module .get_submodule (target )
132
+ return getattr (module , targets [- 1 ])
133
+
134
+
135
def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
    """Return True if ``node`` is a scalar constant too large to observe.

    HistogramObserver relies on ``torch.histc``, which only handles values up
    to a certain magnitude, so such nodes must be skipped during quantization
    annotation.
    """
    # Only constant attributes fetched from the graph module can be scalars.
    if node.op != "get_attr" or not isinstance(node.target, str):
        return False
    # Largest magnitude torch.histc can handle.
    HISTC_UPPER_BOUND = 3.4028235e15
    tensor = _get_node_target(gm, node.target)
    return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
146
+
147
+
148
+ def _is_non_float_tensor (node : Node ) -> bool :
149
+ """Check if the output of a node has a data type other than `torch.float32`.
150
+
151
+ If the output is not `torch.float32`, quantization cannot be performed, as
152
+ observers only work with floating-point tensors.
153
+
154
+ Args:
155
+ node (Node): The node to check the output(s) for.
156
+
157
+ Returns:
158
+ bool: `True` if the data type is not float32, otherwise `False`.
159
+
160
+ Note:
161
+ - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
162
+ element is **not** an instance of `FakeTensor` or does **not** have
163
+ `torch.float32` as its data type.
164
+ - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
165
+ function returns True.
166
+ """
167
+ if "val" in node .meta and isinstance (node .meta ["val" ], Sequence ):
168
+ return any (
169
+ not isinstance (fake_tensor , FakeTensor )
170
+ or fake_tensor .dtype != torch .float32
171
+ for fake_tensor in node .meta ["val" ]
172
+ )
173
+
174
+ if "val" not in node .meta or not isinstance (node .meta ["val" ], FakeTensor ):
175
+ return True
176
+
177
+ return node .meta ["val" ].dtype != torch .float32
178
+
179
+
113
180
def _annotate_input (node : Node , quant_property : _QuantProperty ):
114
181
if is_annotated (node ):
115
182
raise RuntimeError (
0 commit comments