|
9 | 9 |
|
10 | 10 | import math |
11 | 11 |
|
12 | | -from typing import Any, cast, NamedTuple, Tuple |
| 12 | +from typing import Any, Tuple |
13 | 13 |
|
14 | 14 | import executorch.backends.arm.tosa_specification as tosa_specification |
15 | 15 |
|
16 | 16 | import torch.fx |
17 | 17 | import torch.fx.node |
18 | | -from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS |
19 | 18 |
|
20 | 19 | from executorch.backends.arm.tosa_mapping import TosaArg |
21 | | -from executorch.exir.dialects._ops import ops as exir_ops |
22 | | -from torch import Tensor |
23 | 20 | from torch.fx import Node |
24 | 21 | from tosa.RoundingMode import RoundingMode # type: ignore |
25 | 22 |
|
@@ -109,122 +106,6 @@ def insert_rescale_op_to_int8( |
109 | 106 | ) |
110 | 107 |
|
111 | 108 |
|
112 | | -class QuantArgs(NamedTuple): |
113 | | - scale: list[float] | float |
114 | | - zp: list[int] | int |
115 | | - qmin: int |
116 | | - qmax: int |
117 | | - dtype: torch.dtype |
118 | | - axis: int = 0 |
119 | | - per_channel: bool = False |
120 | | - |
121 | | - def quantize_value(self, x: torch.Tensor | float) -> Tensor: |
122 | | - """Quantizes the input tensor or value to a quantized tensor. If the input is |
123 | | - not a tensor, it is converted to a tensor first. If self.per_channel is True, |
124 | | - the quantization is done per channel, otherwise it is done per tensor. |
125 | | - """ |
126 | | - if not isinstance(x, torch.Tensor): |
127 | | - x = torch.Tensor([x]) |
128 | | - x = x.to(torch.float32) |
129 | | - if self.per_channel: |
130 | | - q_op = exir_ops.edge.quantized_decomposed.quantize_per_channel.default |
131 | | - args = ( |
132 | | - x, |
133 | | - torch.tensor(self.scale), |
134 | | - torch.tensor(self.zp), |
135 | | - self.axis, |
136 | | - self.qmin, |
137 | | - self.qmax, |
138 | | - self.dtype, |
139 | | - ) |
140 | | - else: |
141 | | - q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default |
142 | | - args = (x, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment] |
143 | | - |
144 | | - return q_op(*args) |
145 | | - |
146 | | - def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor: |
147 | | - """Dequantizes the input tensor or value to a dequantized tensor If the input |
148 | | - is not a tensor, it is converted to a tensor first. If self.per_channel is True, |
149 | | - the dequantization is done per channel, otherwise it is done per tensor. |
150 | | - """ |
151 | | - if self.per_channel: |
152 | | - dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_channel.default |
153 | | - args = ( |
154 | | - qx, |
155 | | - torch.tensor(self.scale), |
156 | | - torch.tensor(self.zp), |
157 | | - self.axis, |
158 | | - self.qmin, |
159 | | - self.qmax, |
160 | | - self.dtype, |
161 | | - ) |
162 | | - else: |
163 | | - dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default |
164 | | - args = (qx, self.scale, self.zp, self.qmin, self.qmax, self.dtype) # type: ignore[assignment] |
165 | | - |
166 | | - return dq_op(*args) |
167 | | - |
168 | | - @classmethod |
169 | | - def from_operator(cls, op, args): |
170 | | - if op in PER_TENSOR_QDQ_OPS: |
171 | | - return cls( |
172 | | - scale=cast(float, args[1]), |
173 | | - zp=cast(int, args[2]), |
174 | | - qmin=cast(int, args[3]), |
175 | | - qmax=cast(int, args[4]), |
176 | | - dtype=cast(torch.dtype, args[5]), |
177 | | - axis=0, |
178 | | - per_channel=False, |
179 | | - ) |
180 | | - elif op in PER_CHANNEL_QDQ_OPS: |
181 | | - return cls( |
182 | | - scale=cast(list[float], args[1].tolist()), |
183 | | - zp=cast(list[int], args[2].tolist()), |
184 | | - axis=cast(int, args[3]), |
185 | | - qmin=cast(int, args[4]), |
186 | | - qmax=cast(int, args[5]), |
187 | | - dtype=cast(torch.dtype, args[6]), |
188 | | - per_channel=True, |
189 | | - ) |
190 | | - |
191 | | - else: |
192 | | - # We're only handling per tensor and per channel quantization |
193 | | - raise NotImplementedError(f"Unsupported quantization operation: {op}") |
194 | | - |
195 | | - def get_scale_per_tensor(self) -> float: |
196 | | - if not isinstance(self.scale, float): |
197 | | - raise TypeError( |
198 | | - f"Expected scale {self.scale} to be a float but found scale of " |
199 | | - f"type {type(self.scale)}" |
200 | | - ) |
201 | | - return self.scale |
202 | | - |
203 | | - def get_zp_per_tensor(self) -> int: |
204 | | - if not isinstance(self.zp, int): |
205 | | - raise TypeError( |
206 | | - f"Expected zero point {self.zp} to be an int but found zp of " |
207 | | - f"type {type(self.zp)}" |
208 | | - ) |
209 | | - return self.zp |
210 | | - |
211 | | - def get_scale_per_channel(self) -> list[float]: |
212 | | - if not isinstance(self.scale, list): |
213 | | - raise TypeError( |
214 | | - f"Expected scale {self.scale} to be a list but found scale of " |
215 | | - f"type {type(self.scale)}" |
216 | | - ) |
217 | | - return self.scale |
218 | | - |
219 | | - def get_zp_per_channel(self) -> list[int]: |
220 | | - if not isinstance(self.zp, list): |
221 | | - raise TypeError( |
222 | | - f"Expected zero point {self.zp} to be a list but found zp of " |
223 | | - f"type {type(self.zp)}" |
224 | | - ) |
225 | | - return self.zp |
226 | | - |
227 | | - |
228 | 109 | # TOSA uses the RESCALE operation to scale between values with differing precision. |
229 | 110 | # The RESCALE operator is defined using an integer multiply, add, and shift. |
230 | 111 | # This utility function is for calculating the multiplier and shift given a scale. |
|
0 commit comments