66import logging
77import operator
88from dataclasses import dataclass
9- from typing import Callable , List , Optional
9+ from typing import Callable , List , Optional , Sequence
1010
1111import torch
1212import torch .fx
1313import torch .nn .functional as F
1414from executorch .backends .arm .quantizer import QuantizationConfig
1515from executorch .backends .arm .tosa_utils import get_node_debug_info
16+ from torch ._subclasses import FakeTensor
1617
1718from torch .fx import Node
1819from torchao .quantization .pt2e .quantizer import (
2425
2526from .arm_quantizer_utils import (
2627 is_annotated ,
27- is_ok_for_quantization ,
2828 is_output_annotated ,
2929 mark_node_as_annotated ,
3030)
@@ -78,9 +78,16 @@ def _is_ok_for_quantization(
7878 """
7979 # Check output
8080 if quant_properties .quant_output is not None :
81- if not is_ok_for_quantization (node , gm ): # type: ignore[attr-defined]
81+ if _is_non_float_tensor (node ):
8282 logger .debug (
83- f"Could not quantize node due to output: "
83+ "Could not quantize non float tensor for the following output node: "
84+ f"{ get_node_debug_info (node , gm )} "
85+ )
86+
87+ return False
88+ elif _is_large_scalar (node , gm ):
89+ logger .debug (
90+ "Could not quantize large scalar node for the following output node: "
8491 f"{ get_node_debug_info (node , gm )} "
8592 )
8693
@@ -99,17 +106,77 @@ def _is_ok_for_quantization(
99106 raise TypeError (
100107 f"n_arg must be a Node instance, got { type (n_arg ).__name__ !r} "
101108 )
102- if not is_ok_for_quantization (n_arg , gm ): # type: ignore[attr-defined]
109+
110+ if _is_non_float_tensor (n_arg ):
103111 logger .debug (
104- f'could not quantize node due to input "{ node } ": '
105- f"{ get_node_debug_info (node , gm )} "
112+ "Could not quantize non float tensor for the following input "
113+ f"node: { get_node_debug_info (node , gm )} "
114+ )
115+
116+ return False
117+ elif _is_large_scalar (n_arg , gm ):
118+ logger .debug (
119+ "Could not quantize large scalar node for the following input "
120+ f"node: { get_node_debug_info (node , gm )} "
106121 )
107122
108123 return False
109124
110125 return True
111126
112127
128+ def _get_node_target (module : torch .nn .Module | torch .fx .GraphModule , target_str : str ):
129+ targets = target_str .split ("." )
130+ for target in targets [:- 1 ]:
131+ module = module .get_submodule (target )
132+ return getattr (module , targets [- 1 ])
133+
134+
def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
    """Check whether ``node`` is a single-element constant too large to quantize.

    HistogramObserver relies on ``torch.histc``, which only supports values up
    to a certain magnitude, so scalar constants beyond that bound must skip
    quantization.
    """
    # torch.histc works until this upper bound
    HISTC_UPPER_BOUND = 3.4028235e15
    # Only get_attr nodes with a string target can hold a constant tensor.
    if node.op != "get_attr" or not isinstance(node.target, str):
        return False
    scalar = _get_node_target(gm, node.target)
    return scalar.numel() == 1 and abs(scalar.item()) > HISTC_UPPER_BOUND
146+
147+
148+ def _is_non_float_tensor (node : Node ) -> bool :
149+ """Check if the output of a node has a data type other than `torch.float32`.
150+
151+ If the output is not `torch.float32`, quantization cannot be performed, as
152+ observers only work with floating-point tensors.
153+
154+ Args:
155+ node (Node): The node to check the output(s) for.
156+
157+ Returns:
158+ bool: `True` if the data type is not float32, otherwise `False`.
159+
160+ Note:
161+ - If `node.meta["val"]` is a `list`, the function returns `True` if **any**
162+ element is **not** an instance of `FakeTensor` or does **not** have
163+ `torch.float32` as its data type.
164+ - If node.meta["val"] is missing or is not an instance of `FakeTensor`, the
165+ function returns True.
166+ """
167+ if "val" in node .meta and isinstance (node .meta ["val" ], Sequence ):
168+ return any (
169+ not isinstance (fake_tensor , FakeTensor )
170+ or fake_tensor .dtype != torch .float32
171+ for fake_tensor in node .meta ["val" ]
172+ )
173+
174+ if "val" not in node .meta or not isinstance (node .meta ["val" ], FakeTensor ):
175+ return True
176+
177+ return node .meta ["val" ].dtype != torch .float32
178+
179+
113180def _annotate_input (node : Node , quant_property : _QuantProperty ):
114181 if is_annotated (node ):
115182 raise RuntimeError (
0 commit comments