7 | 7 | from itertools import chain |
8 | 8 | from typing import Callable, cast, Dict, Iterator, Set, Type |
9 | 9 |
| 10 | +import math |
10 | 11 | import torch |
11 | 12 | from executorch.backends.arm._passes import ArmPass |
12 | 13 | from executorch.backends.arm._passes.arm_pass_utils import create_node |
@@ -119,9 +120,10 @@ class InsertTableOpsPass(ArmPass): |
119 | 120 |
120 | 121 | _passes_required_after: Set[Type[ExportPass]] = set() |
121 | 122 |
122 | | - def __init__(self, exported_program: ExportedProgram) -> None: |
| 123 | + def __init__(self, exported_program: ExportedProgram, tosa_spec=None) -> None: |
123 | 124 | super().__init__() |
124 | 125 | self.exported_program = exported_program |
| 126 | + self.tosa_spec = tosa_spec |
125 | 127 | self.table_ops = TableOps(exported_program) |
126 | 128 |
127 | 129 | def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None: |
@@ -157,7 +159,77 @@ def f(x: torch.Tensor) -> torch.Tensor: |
157 | 159 | 0, |
158 | 160 | ) |
159 | 161 |
160 | | - def generate_16_bit_table_values( |
| 162 | + def generate_16_bit_table_values_u55_tflite( |
| 163 | + self, |
| 164 | + torch_op: Callable[[torch.Tensor], torch.Tensor], |
| 165 | + in_quantargs: QuantArgs, |
| 166 | + out_quantargs: QuantArgs, |
| 167 | + ) -> tuple[torch.Tensor, int]: |
| 168 | + """ |
| 169 | + Generate table values for U55 using TFLite-style bias correction. |
 | 170 | + |
 | 171 | + 1. Evaluate the function at the base, midpoint, and next base of each interval |
| 172 | + 2. Quantize all three output values |
| 173 | + 3. Calculate bias = (interpolated_midpoint - actual_midpoint) / 2 |
| 174 | + 4. Apply bias correction to base value |
| 175 | + 5. Store corrected base values (513 values total) |
| 176 | + """ |
| 177 | + |
 | 178 | + # Compute the input and output ranges in float space (as TFLite does) |
| 179 | + qmin_in = in_quantargs.qmin |
| 180 | + qmax_in = in_quantargs.qmax |
| 181 | + qmin_out = out_quantargs.qmin |
| 182 | + qmax_out = out_quantargs.qmax |
| 183 | + |
| 184 | + input_min = in_quantargs.scale * (qmin_in - in_quantargs.zp) # type: ignore[operator] |
| 185 | + input_max = in_quantargs.scale * (qmax_in - in_quantargs.zp) # type: ignore[operator] |
| 186 | + output_min = out_quantargs.scale * (qmin_out - out_quantargs.zp) # type: ignore[operator] |
| 187 | + output_max = out_quantargs.scale * (qmax_out - out_quantargs.zp) # type: ignore[operator] |
| 188 | + |
| 189 | + steps = 512 |
| 190 | + step = (input_max - input_min) / steps |
| 191 | + half_step = step / 2.0 |
| 192 | + output_scaling_inv = (qmax_out - qmin_out + 1) / (output_max - output_min) |
| 193 | + |
| 194 | + |
| 195 | + def f(x_float: float) -> float: |
| 196 | + x_tensor = torch.tensor([x_float], dtype=torch.float32) |
| 197 | + result = torch_op(x_tensor).item() |
| 198 | + |
| 199 | + if math.isnan(result) or math.isinf(result): |
| 200 | + return input_max |
| 201 | + |
| 202 | + return result |
| 203 | + |
| 204 | + lut_values = [] |
| 205 | + |
| 206 | + for i in range(steps + 1): # 513 values |
| 207 | + val = f(input_min + i * step) |
| 208 | + sample_val = round(val * output_scaling_inv) |
| 209 | + |
| 210 | + if i < steps: |
| 211 | + val_midpoint = f(input_min + i * step + half_step) |
| 212 | + val_next = f(input_min + (i + 1) * step) |
| 213 | + |
| 214 | + midpoint_interp_val = round( |
| 215 | + (val_next * output_scaling_inv + sample_val) / 2.0 |
| 216 | + ) |
| 217 | + midpoint_val = round(val_midpoint * output_scaling_inv) |
| 218 | + midpoint_err = midpoint_interp_val - midpoint_val |
| 219 | + bias = round(midpoint_err / 2.0) |
| 220 | + |
| 221 | + clamped_lut_result = max(qmin_out, min(qmax_out, sample_val - bias)) |
| 222 | + lut_result = int(clamped_lut_result) |
| 223 | + |
| 224 | + lut_values.append(lut_result) |
| 225 | + else: |
| 226 | + # Last value (i == steps): no bias correction, just quantize and clamp |
| 227 | + clamped = max(qmin_out, min(qmax_out, sample_val)) |
| 228 | + lut_values.append(int(clamped)) |
| 229 | + |
| 230 | + return torch.tensor(lut_values, dtype=torch.int16).contiguous(), 0 |
| 231 | + |
| 232 | + def generate_16_bit_table_values_tosa( |
161 | 233 | self, |
162 | 234 | torch_op: Callable[[torch.Tensor], torch.Tensor], |
163 | 235 | in_quantargs: QuantArgs, |
@@ -210,6 +282,26 @@ def f(x: torch.Tensor) -> torch.Tensor: |
210 | 282 | lut_values = lut_values >> rshift |
211 | 283 | return lut_values.to(dtype=torch.int16), rescale_lshift |
212 | 284 |
| 285 | + def generate_16_bit_table_values( |
| 286 | + self, |
| 287 | + torch_op: Callable[[torch.Tensor], torch.Tensor], |
| 288 | + in_quantargs: QuantArgs, |
| 289 | + out_quantargs: QuantArgs, |
| 290 | + ) -> tuple[torch.Tensor, int]: |
| 291 | + """Compute LUT values for a INT16 tables. |
| 292 | + The function returns rescale_lshift which says how much to rescale after the table. This value can negative. |
| 293 | + """ |
| 294 | + |
| 295 | + if self.tosa_spec and self.tosa_spec.is_U55_subset: |
| 296 | + # U55 needs TFLite-style table generation with bias correction |
| 297 | + return self.generate_16_bit_table_values_u55_tflite( |
| 298 | + torch_op, in_quantargs, out_quantargs |
| 299 | + ) |
| 300 | + else: |
| 301 | + return self.generate_16_bit_table_values_tosa( |
| 302 | + torch_op, in_quantargs, out_quantargs |
| 303 | + ) |
| 304 | + |
213 | 305 | def generate_table_values( |
214 | 306 | self, |
215 | 307 | torch_op: Callable[[torch.Tensor], torch.Tensor], |
@@ -280,7 +372,15 @@ def call(self, graph_module: GraphModule) -> PassResult: |
280 | 372 | ) |
281 | 373 | output_node = table_op_node |
282 | 374 |
283 | | - if lshift != 0: |
| 375 | + if ( |
| 376 | + self.tosa_spec |
| 377 | + and self.tosa_spec.is_U55_subset |
| 378 | + and input_qparams[0].dtype == torch.int16 |
| 379 | + ): |
 | 380 | + # U55 int16: no RESCALE is needed; the table output is used directly. |
 | 381 | + # Inserting a RESCALE would add a second operation that overwrites the table output. |
 | 382 | + output_node = table_op_node |
| 383 | + elif lshift != 0: |
284 | 384 | scale = 2.0**lshift |
285 | 385 | rescale_node = create_node( |
286 | 386 | graph=graph_module.graph, |
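
For reference, below is a minimal standalone sketch of the 513-entry, TFLite-style LUT construction with midpoint bias correction that the new `generate_16_bit_table_values_u55_tflite` performs. The helper name `build_lut_with_bias_correction`, the symmetric int16 quantization parameters, and the use of `torch.tanh` as the table operator are illustrative assumptions only; in the pass itself the scales and zero points come from the node's `QuantArgs`.

```python
import math

import torch


def build_lut_with_bias_correction(
    torch_op, in_scale, in_zp, out_scale, out_zp, qmin=-32768, qmax=32767
):
    """Return the 513 int16 LUT entries for torch_op, with midpoint bias correction."""
    # Input/output ranges in float space, as in the pass above.
    input_min = in_scale * (qmin - in_zp)
    input_max = in_scale * (qmax - in_zp)
    output_min = out_scale * (qmin - out_zp)
    output_max = out_scale * (qmax - out_zp)

    steps = 512
    step = (input_max - input_min) / steps
    half_step = step / 2.0
    output_scaling_inv = (qmax - qmin + 1) / (output_max - output_min)

    def f(x):
        y = torch_op(torch.tensor([x], dtype=torch.float32)).item()
        # Mirror the pass: map NaN/inf results to a finite sentinel before quantizing.
        return input_max if (math.isnan(y) or math.isinf(y)) else y

    values = []
    for i in range(steps + 1):
        sample = round(f(input_min + i * step) * output_scaling_inv)
        if i < steps:
            # Quantized true midpoint vs. midpoint interpolated from the two interval
            # endpoints; half of the error is folded back into the base entry.
            midpoint = round(f(input_min + i * step + half_step) * output_scaling_inv)
            interp = round((f(input_min + (i + 1) * step) * output_scaling_inv + sample) / 2.0)
            bias = round((interp - midpoint) / 2.0)
            values.append(int(max(qmin, min(qmax, sample - bias))))
        else:
            # Final entry: no bias correction, just quantize and clamp.
            values.append(int(max(qmin, min(qmax, sample))))
    return torch.tensor(values, dtype=torch.int16)


# Example: a tanh table for symmetric int16 I/O quantization (hypothetical scales).
lut = build_lut_with_bias_correction(
    torch.tanh, in_scale=4.0 / 32768, in_zp=0, out_scale=1.0 / 32768, out_zp=0
)
print(lut.shape)  # torch.Size([513])
```

At each of the 512 intervals, half of the error between the interpolated midpoint (the average of the two quantized endpoints) and the quantized true midpoint is subtracted from the base entry, matching the bias correction in the diff; the last of the 513 entries is simply quantized and clamped.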