|  | 
| 7 | 7 | from itertools import chain | 
| 8 | 8 | from typing import Callable, cast, Dict, Iterator, Set, Type | 
| 9 | 9 |  | 
|  | 10 | +import math | 
| 10 | 11 | import torch | 
| 11 | 12 | from executorch.backends.arm._passes import ArmPass | 
| 12 | 13 | from executorch.backends.arm._passes.arm_pass_utils import create_node | 
| @@ -119,9 +120,10 @@ class InsertTableOpsPass(ArmPass): | 
| 119 | 120 |  | 
| 120 | 121 |     _passes_required_after: Set[Type[ExportPass]] = set() | 
| 121 | 122 |  | 
| 122 |  | -    def __init__(self, exported_program: ExportedProgram) -> None: | 
|  | 123 | +    def __init__(self, exported_program: ExportedProgram, tosa_spec=None) -> None: | 
| 123 | 124 |         super().__init__() | 
| 124 | 125 |         self.exported_program = exported_program | 
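|  |  | +        # Optional TOSA spec; targets in the U55 subset get TFLite-style INT16 table generation | 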
|  | 126 | +        self.tosa_spec = tosa_spec | 
| 125 | 127 |         self.table_ops = TableOps(exported_program) | 
| 126 | 128 |  | 
| 127 | 129 |     def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None: | 
| @@ -157,7 +159,77 @@ def f(x: torch.Tensor) -> torch.Tensor: | 
| 157 | 159 |             0, | 
| 158 | 160 |         ) | 
| 159 | 161 |  | 
| 160 |  | -    def generate_16_bit_table_values( | 
|  | 162 | +    def generate_16_bit_table_values_u55_tflite( | 
|  | 163 | +        self, | 
|  | 164 | +        torch_op: Callable[[torch.Tensor], torch.Tensor], | 
|  | 165 | +        in_quantargs: QuantArgs, | 
|  | 166 | +        out_quantargs: QuantArgs, | 
|  | 167 | +    ) -> tuple[torch.Tensor, int]: | 
|  | 168 | +        """ | 
|  | 169 | +        Generate table values for U55 using TFLite-style bias correction. | 
|  | 170 | + | 
|  | 171 | +        1. Evaluate function at base, midpoint, and next for each interval | 
|  | 172 | +        2. Quantize all three output values | 
|  | 173 | +        3. Calculate bias = (interpolated_midpoint - actual_midpoint) / 2 | 
|  | 174 | +        4. Apply bias correction to base value | 
|  | 175 | +        5. Store corrected base values (513 values total) | 
|  | 176 | +        """ | 
|  | 177 | + | 
|  | 178 | +        # Calculate the input and output ranges in float space (as TFLite does) | 
|  | 179 | +        qmin_in = in_quantargs.qmin | 
|  | 180 | +        qmax_in = in_quantargs.qmax | 
|  | 181 | +        qmin_out = out_quantargs.qmin | 
|  | 182 | +        qmax_out = out_quantargs.qmax | 
|  | 183 | + | 
|  | 184 | +        input_min = in_quantargs.scale * (qmin_in - in_quantargs.zp)  # type: ignore[operator] | 
|  | 185 | +        input_max = in_quantargs.scale * (qmax_in - in_quantargs.zp)  # type: ignore[operator] | 
|  | 186 | +        output_min = out_quantargs.scale * (qmin_out - out_quantargs.zp)  # type: ignore[operator] | 
|  | 187 | +        output_max = out_quantargs.scale * (qmax_out - out_quantargs.zp)  # type: ignore[operator] | 
|  | 188 | + | 
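|  |  | +        # 512 intervals -> 513 table entries for the INT16 LUT | 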
|  | 189 | +        steps = 512 | 
|  | 190 | +        step = (input_max - input_min) / steps | 
|  | 191 | +        half_step = step / 2.0 | 
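|  |  | +        # Inverse output scale: converts float outputs to quantized output units | 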
|  | 192 | +        output_scaling_inv = (qmax_out - qmin_out + 1) / (output_max - output_min) | 
|  | 193 | + | 
|  | 194 | + | 
|  | 195 | +        def f(x_float: float) -> float: | 
|  | 196 | +            x_tensor = torch.tensor([x_float], dtype=torch.float32) | 
|  | 197 | +            result = torch_op(x_tensor).item() | 
|  | 198 | + | 
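|  |  | +            # Clamp non-finite results (NaN/Inf) to input_max | 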
|  | 199 | +            if math.isnan(result) or math.isinf(result): | 
|  | 200 | +                return input_max | 
|  | 201 | + | 
|  | 202 | +            return result | 
|  | 203 | + | 
|  | 204 | +        lut_values = [] | 
|  | 205 | + | 
|  | 206 | +        for i in range(steps + 1):  # 513 values | 
|  | 207 | +            val = f(input_min + i * step) | 
|  | 208 | +            sample_val = round(val * output_scaling_inv) | 
|  | 209 | + | 
|  | 210 | +            if i < steps: | 
|  | 211 | +                val_midpoint = f(input_min + i * step + half_step) | 
|  | 212 | +                val_next = f(input_min + (i + 1) * step) | 
|  | 213 | + | 
|  | 214 | +                midpoint_interp_val = round( | 
|  | 215 | +                    (val_next * output_scaling_inv + sample_val) / 2.0 | 
|  | 216 | +                ) | 
|  | 217 | +                midpoint_val = round(val_midpoint * output_scaling_inv) | 
|  | 218 | +                midpoint_err = midpoint_interp_val - midpoint_val | 
|  | 219 | +                bias = round(midpoint_err / 2.0) | 
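|  |  | +                # Example: base 100, next 108, true midpoint 102 -> interpolated midpoint 104, | 
|  |  | +                # midpoint_err = 2, bias = 1, so 99 is stored instead of 100 | 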
|  | 220 | + | 
|  | 221 | +                clamped_lut_result = max(qmin_out, min(qmax_out, sample_val - bias)) | 
|  | 222 | +                lut_result = int(clamped_lut_result) | 
|  | 223 | + | 
|  | 224 | +                lut_values.append(lut_result) | 
|  | 225 | +            else: | 
|  | 226 | +                # Last value (i == steps): no bias correction, just quantize and clamp | 
|  | 227 | +                clamped = max(qmin_out, min(qmax_out, sample_val)) | 
|  | 228 | +                lut_values.append(int(clamped)) | 
|  | 229 | + | 
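|  |  | +        # The returned shift of 0 means the values are already in the output range (no post-table RESCALE) | 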
|  | 230 | +        return torch.tensor(lut_values, dtype=torch.int16).contiguous(), 0 | 
|  | 231 | + | 
|  | 232 | +    def generate_16_bit_table_values_tosa( | 
| 161 | 233 |         self, | 
| 162 | 234 |         torch_op: Callable[[torch.Tensor], torch.Tensor], | 
| 163 | 235 |         in_quantargs: QuantArgs, | 
| @@ -210,6 +282,26 @@ def f(x: torch.Tensor) -> torch.Tensor: | 
| 210 | 282 |         lut_values = lut_values >> rshift | 
| 211 | 283 |         return lut_values.to(dtype=torch.int16), rescale_lshift | 
| 212 | 284 |  | 
|  | 285 | +    def generate_16_bit_table_values( | 
|  | 286 | +        self, | 
|  | 287 | +        torch_op: Callable[[torch.Tensor], torch.Tensor], | 
|  | 288 | +        in_quantargs: QuantArgs, | 
|  | 289 | +        out_quantargs: QuantArgs, | 
|  | 290 | +    ) -> tuple[torch.Tensor, int]: | 
|  | 291 | +        """Compute LUT values for a INT16 tables. | 
|  | 292 | +        The function returns rescale_lshift which says how much to rescale after the table. This value can negative. | 
|  | 293 | +        """ | 
|  | 294 | + | 
|  | 295 | +        if self.tosa_spec and self.tosa_spec.is_U55_subset: | 
|  | 296 | +            # U55 needs TFLite-style table generation with bias correction | 
|  | 297 | +            return self.generate_16_bit_table_values_u55_tflite( | 
|  | 298 | +                torch_op, in_quantargs, out_quantargs | 
|  | 299 | +            ) | 
|  | 300 | +        else: | 
|  | 301 | +            return self.generate_16_bit_table_values_tosa( | 
|  | 302 | +                torch_op, in_quantargs, out_quantargs | 
|  | 303 | +            ) | 
|  | 304 | + | 
| 213 | 305 |     def generate_table_values( | 
| 214 | 306 |         self, | 
| 215 | 307 |         torch_op: Callable[[torch.Tensor], torch.Tensor], | 
| @@ -280,7 +372,15 @@ def call(self, graph_module: GraphModule) -> PassResult: | 
| 280 | 372 |                 ) | 
| 281 | 373 |                 output_node = table_op_node | 
| 282 | 374 |  | 
| 283 |  | -                if lshift != 0: | 
|  | 375 | +                if ( | 
|  | 376 | +                    self.tosa_spec | 
|  | 377 | +                    and self.tosa_spec.is_U55_subset | 
|  | 378 | +                    and input_qparams[0].dtype == torch.int16 | 
|  | 379 | +                ): | 
|  | 380 | +                    # U55: no RESCALE needed; use the table output directly. | 
|  | 381 | +                    # Adding a RESCALE would create a second operation that overwrites the table output. | 
|  | 382 | +                    output_node = table_op_node | 
|  | 383 | +                elif lshift != 0: | 
| 284 | 384 |                     scale = 2.0**lshift | 
| 285 | 385 |                     rescale_node = create_node( | 
| 286 | 386 |                         graph=graph_module.graph, | 
|  | 