|
4 | 4 | # LICENSE file in the root directory of this source tree. |
5 | 5 |
|
6 | 6 |
|
7 | | -from typing import Any, cast, List |
| 7 | +import math |
| 8 | +from typing import Any, cast, List, Tuple |
8 | 9 |
|
9 | 10 | import torch |
10 | 11 |
|
|
19 | 20 |
|
20 | 21 | from executorch.backends.arm.tosa import TosaSpecification |
21 | 22 | from executorch.backends.arm.tosa.mapping import map_dtype, TosaArg |
22 | | -from executorch.backends.arm.tosa.quant_utils import build_rescale |
23 | 23 | from torch.fx import Node |
24 | 24 |
|
25 | 25 |
|
| 26 | +# TOSA uses the RESCALE operation to scale between values with differing precision. |
| 27 | +# The RESCALE operator is defined using an integer multiply, add, and shift. |
| 28 | +# This utility function is for calculating the multiplier and shift given a scale. |
| 29 | +# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling |
| 30 | +def _compute_multiplier_and_shift( |
| 31 | + scales: list[float], scaleWidth: int = 32 |
| 32 | +) -> Tuple[list[int], list[int]]: |
| 33 | + if scaleWidth == 16: |
| 34 | + offset = 15 |
| 35 | + elif scaleWidth == 32: |
| 36 | + offset = 31 |
| 37 | + else: |
| 38 | + raise ValueError( |
| 39 | + f"Unsupported scale width: {scaleWidth}, only 16 and 32 are valid values." |
| 40 | + ) |
| 41 | + |
| 42 | + multipliers = [] |
| 43 | + shifts = [] |
| 44 | + for scale in scales: |
| 45 | + mantissa, exponent = math.frexp(scale) |
| 46 | + shift = exponent |
| 47 | + |
| 48 | + const_2_power_15_or_31 = 1 << offset |
| 49 | + shifted_mantissa = round(mantissa * const_2_power_15_or_31) |
| 50 | + |
| 51 | + assert ( |
| 52 | + shifted_mantissa <= const_2_power_15_or_31 |
| 53 | + ), f"Mantissa {shifted_mantissa} exceeds limit {const_2_power_15_or_31}" |
| 54 | + |
| 55 | + if shifted_mantissa == const_2_power_15_or_31: |
| 56 | + shifted_mantissa = shifted_mantissa // 2 |
| 57 | + shift += 1 |
| 58 | + |
| 59 | + # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits. |
| 60 | + shift = offset - shift |
| 61 | + |
| 62 | + # INT32_MAX, 2^31 - 1 |
| 63 | + assert shifted_mantissa <= (const_2_power_15_or_31 - 1), ( |
| 64 | + f"Mantissa {shifted_mantissa} exceeds signed max " |
| 65 | + f"{const_2_power_15_or_31 - 1}" |
| 66 | + ) |
| 67 | + |
| 68 | + multiplier = shifted_mantissa |
| 69 | + |
| 70 | + if shift > 62: |
| 71 | + multiplier = multiplier >> min(31, shift - 62) |
| 72 | + shift = 62 |
| 73 | + |
| 74 | + assert multiplier >= 0, "Multiplier should be non-negative" |
| 75 | + assert shift >= 2 and shift <= 62, "Shift should be in range [2, 62]" |
| 76 | + multipliers.append(multiplier) |
| 77 | + shifts.append(shift) |
| 78 | + return multipliers, shifts |
| 79 | + |
| 80 | + |
| 81 | +# For TOSA spec v1.0 RESCALE operator requires multiplier, shifts, input_zp and output_zp to be |
| 82 | +# const inputs. Create constant operators from the data already initialized. |
| 83 | +def _create_const_ops_for_rescale( |
| 84 | + tosa_fb, |
| 85 | + scale_32, |
| 86 | + input_dtype, |
| 87 | + node_name, |
| 88 | + multipliers, |
| 89 | + shifts, |
| 90 | + input_zp, |
| 91 | + output_zp, |
| 92 | + output_dtype, |
| 93 | + ts, |
| 94 | +): |
| 95 | + |
| 96 | + multipliers = tosa_fb.addConst( |
| 97 | + (len(multipliers),), |
| 98 | + ts.DType.INT32 if scale_32 else ts.DType.INT16, |
| 99 | + multipliers, |
| 100 | + name=node_name + "_multipliers", |
| 101 | + ) |
| 102 | + shifts = tosa_fb.addConst( |
| 103 | + (len(shifts),), ts.DType.INT8, shifts, name=node_name + "_shifts" |
| 104 | + ) |
| 105 | + input_zp = tosa_fb.addConst( |
| 106 | + [1], input_dtype, input_zp, name=node_name + "_input_zp" |
| 107 | + ) |
| 108 | + output_zp = tosa_fb.addConst( |
| 109 | + [1], output_dtype, output_zp, name=node_name + "_output_zp" |
| 110 | + ) |
| 111 | + |
| 112 | + return [multipliers.name, shifts.name, input_zp.name, output_zp.name] |
| 113 | + |
| 114 | + |
| 115 | +def _build_rescale( |
| 116 | + tosa_fb: Any, |
| 117 | + scale: list[float], |
| 118 | + input_node: Any, |
| 119 | + output_name: str, |
| 120 | + output_type: Any, |
| 121 | + input_zp: list[int], |
| 122 | + output_zp: list[int], |
| 123 | + rounding_mode: ts.RoundingMode, |
| 124 | + per_channel: bool = False, |
| 125 | + is_scale32: bool = True, |
| 126 | +): |
| 127 | + scaleWidth = 16 if input_node.dtype == ts.DType.INT48 else 32 |
| 128 | + is_scale32 = False if input_node.dtype == ts.DType.INT48 else True |
| 129 | + multipliers, shifts = _compute_multiplier_and_shift(scale, scaleWidth) |
| 130 | + rescale_inputs = _create_const_ops_for_rescale( |
| 131 | + tosa_fb, |
| 132 | + is_scale32, |
| 133 | + input_node.dtype, |
| 134 | + output_name, |
| 135 | + multipliers, |
| 136 | + shifts, |
| 137 | + input_zp, |
| 138 | + output_zp, |
| 139 | + output_type, |
| 140 | + ts, |
| 141 | + ) |
| 142 | + attr_rescale = ts.TosaSerializerAttribute() |
| 143 | + attr_rescale.RescaleAttribute( |
| 144 | + scale32=is_scale32, |
| 145 | + rounding_mode=rounding_mode, |
| 146 | + per_channel=per_channel, |
| 147 | + input_unsigned=False, |
| 148 | + output_unsigned=False, |
| 149 | + ) |
| 150 | + |
| 151 | + tosa_fb.addOperator( |
| 152 | + ts.Op.RESCALE, |
| 153 | + [input_node.name, *rescale_inputs], |
| 154 | + [output_name], |
| 155 | + attr_rescale, |
| 156 | + ) |
| 157 | + |
| 158 | + |
26 | 159 | @register_node_visitor |
27 | 160 | class RescaleVisitor(NodeVisitor): |
28 | 161 | target = "tosa.RESCALE.default" |
@@ -60,7 +193,7 @@ def define_node( |
60 | 193 | f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" |
61 | 194 | ) |
62 | 195 |
|
63 | | - build_rescale( |
| 196 | + _build_rescale( |
64 | 197 | tosa_graph, |
65 | 198 | scale=scales, |
66 | 199 | input_node=inputs[0], |
|
0 commit comments