Skip to content

Commit 76c7ac7

Browse files
3l1 authored and facebook-github-bot committed
Fix U55 int16 table generation (rsqrt, sigmoid, tanh)
Summary: This diff fixes critical runtime bugs in the U55 INT16 table operations (rsqrt, sigmoid, tanh).
**WARNING: This diff goes with the Regor diff D85535937; the two are split only because the Regor change maps to a separate OSS GitHub repo (the Arm Regor git repo).**
bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks
Differential Revision: D85312140
1 parent 57ffbf6 commit 76c7ac7

File tree

5 files changed

+101
-16
lines changed

5 files changed

+101
-16
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
200200

201201
self.add_pass(FuseViewCopyTransform())
202202
self.add_pass(FuseConstantArgsPass(exported_program))
203-
self.add_pass(InsertTableOpsPass(exported_program))
203+
self.add_pass(InsertTableOpsPass(exported_program, self.tosa_spec))
204204
# If we have a conv2d with int16 activation split up into a convolution
205205
# and an addition, to work-around the lack of support for int48 in torch
206206
# needs to happen before RewriteConv2dPass, but after the table ops are inserted
@@ -297,7 +297,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
297297
self.add_pass(RewriteConv2dPass(exported_program))
298298
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
299299
self.add_pass(RewriteUpsamplePass())
300-
self.add_pass(InsertTableOpsPass(exported_program))
300+
self.add_pass(InsertTableOpsPass(exported_program, self.tosa_spec))
301301
self.add_pass(RewriteMatmulPass())
302302
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
303303
self.add_pass(ToTosaMemoryFormatPass(exported_program))

backends/arm/_passes/insert_table_ops.py

Lines changed: 93 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,10 @@ class InsertTableOpsPass(ArmPass):
119119

120120
_passes_required_after: Set[Type[ExportPass]] = set()
121121

122-
def __init__(self, exported_program: ExportedProgram) -> None:
122+
def __init__(self, exported_program: ExportedProgram, tosa_spec=None) -> None:
123123
super().__init__()
124124
self.exported_program = exported_program
125+
self.tosa_spec = tosa_spec
125126
self.table_ops = TableOps(exported_program)
126127

127128
def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
@@ -157,6 +158,82 @@ def f(x: torch.Tensor) -> torch.Tensor:
157158
0,
158159
)
159160

161+
def generate_16_bit_table_values_u55_tflite_style(
    self,
    torch_op: Callable[[torch.Tensor], torch.Tensor],
    in_quantargs: QuantArgs,
    out_quantargs: QuantArgs,
) -> tuple[torch.Tensor, int]:
    """Build a 513-entry int16 lookup table using TFLite-style bias correction.

    The dequantized input range is split into 512 equal intervals. For each
    interval:

    1. ``torch_op`` is evaluated in float space at the interval's base,
       midpoint and end.
    2. All three results are scaled into the quantized output domain.
    3. ``bias = (interpolated_midpoint - actual_midpoint) / 2`` is computed.
    4. The bias is subtracted from the base value and the result is clamped
       to ``[out_quantargs.qmin, out_quantargs.qmax]``.

    The final (513th) entry is the scaled and clamped value at the end of the
    range, without bias correction.

    Args:
        torch_op: Function to tabulate; it is called with a one-element
            float32 tensor and must return a tensor ``.item()`` can read.
        in_quantargs: Input quantization parameters; ``qmin``, ``qmax``,
            ``scale`` and ``zp`` are read.
        out_quantargs: Output quantization parameters; ``qmin``, ``qmax``,
            ``scale`` and ``zp`` are read.

    Returns:
        ``(table, 0)``: a contiguous int16 tensor with 513 entries and a
        rescale shift that is always 0 for this table style.
    """
    import math

    qmin_in = in_quantargs.qmin
    qmax_in = in_quantargs.qmax
    qmin_out = out_quantargs.qmin
    qmax_out = out_quantargs.qmax

    # Dequantized (float) extents of the input and output ranges.
    input_min = in_quantargs.scale * (qmin_in - in_quantargs.zp)
    input_max = in_quantargs.scale * (qmax_in - in_quantargs.zp)
    output_min = out_quantargs.scale * (qmin_out - out_quantargs.zp)
    output_max = out_quantargs.scale * (qmax_out - out_quantargs.zp)

    steps = 512
    step = (input_max - input_min) / steps
    half_step = step / 2.0
    output_scaling_inv = (qmax_out - qmin_out + 1) / (output_max - output_min)

    # Float values that land one quantized step outside the output range, so
    # the clamp below saturates them to qmax_out / qmin_out respectively.
    saturate_high = (qmax_out + 1) / output_scaling_inv
    saturate_low = (qmin_out - 1) / output_scaling_inv

    def round_half_away(value: float) -> int:
        """Round to nearest, ties away from zero.

        Python's built-in round() rounds ties to even, which would halve an
        odd midpoint error to the wrong bias (e.g. 0.5 -> 0 instead of 1).
        """
        magnitude = math.floor(abs(value))
        if abs(value) - magnitude >= 0.5:
            magnitude += 1
        return -magnitude if value < 0 else magnitude

    def f(x_float: float) -> float:
        """Evaluate torch_op at x_float, replacing NaN/inf by saturating values."""
        x_tensor = torch.tensor([x_float], dtype=torch.float32)
        result = torch_op(x_tensor).item()

        if math.isnan(result):
            # NaN has no meaningful direction; saturate high, as the original
            # fallback intended ("will quantize to qmax_out").
            return saturate_high
        if math.isinf(result):
            return saturate_high if result > 0 else saturate_low
        return result

    def quantize(value: float) -> int:
        """Scale a float output value into the quantized output domain."""
        return round_half_away(value * output_scaling_inv)

    def clamp(quantized: int) -> int:
        """Limit a quantized value to the representable output range."""
        return max(qmin_out, min(qmax_out, quantized))

    lut_values = []

    # The end point of interval i is the base point of interval i + 1, so the
    # evaluation is carried over instead of being recomputed.
    val = f(input_min)
    for i in range(steps):
        val_midpoint = f(input_min + i * step + half_step)
        val_next = f(input_min + (i + 1) * step)

        sample_val = quantize(val)
        midpoint_interp_val = round_half_away(
            (val_next * output_scaling_inv + sample_val) / 2.0
        )
        midpoint_val = quantize(val_midpoint)
        midpoint_err = midpoint_interp_val - midpoint_val
        bias = round_half_away(midpoint_err / 2.0)

        lut_values.append(clamp(sample_val - bias))
        val = val_next

    # Last value (index == steps): no bias correction, just quantize and clamp.
    lut_values.append(clamp(quantize(val)))

    buffer = torch.tensor(lut_values, dtype=torch.int16).contiguous()

    return buffer, 0
236+
160237
def generate_16_bit_table_values(
161238
self,
162239
torch_op: Callable[[torch.Tensor], torch.Tensor],
@@ -178,6 +255,12 @@ def generate_16_bit_table_values(
178255
The function returns rescale_lshift which says how much to rescale after the table. This value can be negative.
179256
"""
180257

258+
# U55 needs TFLite-style table generation with bias correction
259+
if self.tosa_spec is not None and self.tosa_spec.is_U55_subset:
260+
return self.generate_16_bit_table_values_u55_tflite_style(
261+
torch_op, in_quantargs, out_quantargs
262+
)
263+
181264
def f(x: torch.Tensor) -> torch.Tensor:
182265
x = x.clamp(in_quantargs.qmin, in_quantargs.qmax).to(
183266
dtype=in_quantargs.dtype
@@ -280,7 +363,15 @@ def call(self, graph_module: GraphModule) -> PassResult:
280363
)
281364
output_node = table_op_node
282365

283-
if lshift != 0:
366+
if (
367+
self.tosa_spec
368+
and self.tosa_spec.is_U55_subset
369+
and input_qparams[0].dtype == torch.int16
370+
):
371+
# U55: NO RESCALE needed - use table output directly
372+
# Adding RESCALE creates a second operation that overwrites the table output!
373+
output_node = table_op_node # Use table output directly!
374+
elif lshift != 0:
284375
scale = 2.0**lshift
285376
rescale_node = create_node(
286377
graph=graph_module.graph,

backends/arm/test/ops/test_rsqrt.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,9 @@ def test_rsqrt_int16_tosa_INT(test_tensor: torch.Tensor):
156156

157157
@common.parametrize("test_tensor", Rsqrt.test_parameters)
158158
@common.XfailIfNoCorstone300
159-
@pytest.mark.xfail(
160-
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
161-
)
159+
# @pytest.mark.xfail(
160+
# reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
161+
# )
162162
def test_rsqrt_int16_u55_INT16(test_tensor: torch.Tensor):
163163
"""Test rsqrt operation with int16 quantization on U55"""
164164
pipeline = EthosU55PipelineINT[input_t1](
@@ -182,9 +182,9 @@ def test_rsqrt_int16_u55_INT16(test_tensor: torch.Tensor):
182182

183183
@common.parametrize("test_tensor", Rsqrt.test_parameters)
184184
@common.XfailIfNoCorstone320
185-
@pytest.mark.xfail(
186-
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
187-
)
185+
# @pytest.mark.xfail(
186+
# reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
187+
# )
188188
def test_rsqrt_int16_u85_INT16(test_tensor: torch.Tensor):
189189
"""Test rsqrt operation with int16 quantization on U85"""
190190
pipeline = EthosU85PipelineINT[input_t1](

backends/arm/test/ops/test_sigmoid.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,6 @@ def test_sigmoid_16a8w_tosa_INT(test_data: torch.Tensor):
312312

313313
@common.parametrize("test_data", test_data_suite)
314314
@common.XfailIfNoCorstone300
315-
@pytest.mark.xfail(
316-
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
317-
)
318315
def test_sigmoid_16a8w_u55_INT16(test_data: torch.Tensor):
319316
"""Test sigmoid operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
320317
per_channel_quantization = False

backends/arm/test/ops/test_tanh.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,6 @@ def test_tanh_16a8w_tosa_INT(test_data: torch.Tensor):
163163

164164
@common.parametrize("test_data", test_data_suite)
165165
@common.XfailIfNoCorstone300
166-
@pytest.mark.xfail(
167-
reason="MLETORCH-707: AssertionError: Output 0 does not match reference output."
168-
)
169166
def test_tanh_16a8w_u55_INT16(test_data: torch.Tensor):
170167
"""Test tanh operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
171168
per_channel_quantization = False

0 commit comments

Comments
 (0)