Commit 65976f2

Adjust tolerance for fp16 exp op to handle reasonable calculation discrepancies
1 parent 75d4b2e commit 65976f2
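For context on the numbers in this change: torch.finfo(torch.float16).eps is 0.0009765625, the fp16_epsilon hard-coded in both new helpers below. The previous fixed atol/rtol of 1e-03 budgets roughly one fp16 ulp of relative error, which the removed @unittest.skip notes was not enough for the XNNPACK fp16 results, so the helpers now budget 2*eps absolute and 6*eps relative. A quick sanity check of those constants, as a plain PyTorch sketch (not part of the commit):

import torch

# fp16 machine epsilon, matching the fp16_epsilon = 9.77e-4 constant used in the diff below
eps = torch.finfo(torch.float16).eps
print(eps)      # 0.0009765625
print(2 * eps)  # ~1.95e-03, the new absolute-tolerance floor
print(6 * eps)  # ~5.86e-03, the new relative-tolerance factor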

File tree

backends/xnnpack/test/ops/test_exp.py
backends/xnnpack/test/ops/test_gelu.py

2 files changed (+52, -12 lines)


backends/xnnpack/test/ops/test_exp.py

Lines changed: 28 additions & 4 deletions
@@ -10,6 +10,23 @@
 from executorch.backends.xnnpack.test.tester import Tester
 
 
+def calculate_fp16_exp_tolerance(ref_output_tensor):
+    # Calculate mixed tolerance for float16 used in XNNPACK's float16 policy
+    fp16_epsilon = 9.77e-4
+    abs_tol = 2 * fp16_epsilon
+    rel_tol = 6 * fp16_epsilon
+
+    ref_abs = ref_output_tensor.abs()
+    mixed_tol = torch.maximum(
+        torch.full_like(ref_abs, abs_tol),
+        ref_abs * rel_tol,
+    )
+
+    final_atol = mixed_tol.max().item()
+
+    return final_atol, rel_tol
+
+
 class TestExp(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
@@ -22,6 +39,16 @@ def forward(self, x):
             return torch.exp(x)
 
     def run_exp_test(self, inputs):
+        input_tensor = inputs[0]
+
+        if input_tensor.dtype == torch.float16:
+            with torch.no_grad():
+                ref_output = torch.exp(input_tensor.to(torch.float32)).to(torch.float16)
+            atol, rtol = calculate_fp16_exp_tolerance(ref_output)
+        else:
+            atol = 1e-03
+            rtol = 1e-03
+
         (
             Tester(self.Exp(), inputs)
             .export()
@@ -31,12 +58,9 @@ def run_exp_test(self, inputs):
             .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )
 
-    # TODO (leafs1): Fix flaky tests. Land fix asap
-    # and cherry-pick onto release/0.7 branch
-    @unittest.skip(reason="For float16, numerical discepancies are too high")
     def test_fp16_exp(self):
         inputs = (torch.randn(20).to(torch.float16),)
         self.run_exp_test(inputs)
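As a worked example of what calculate_fp16_exp_tolerance produces, here is the same arithmetic on a few illustrative reference values (a sketch only; the values are not from the commit):

import torch

# Mirrors calculate_fp16_exp_tolerance above on hypothetical fp16 reference outputs.
fp16_epsilon = 9.77e-4
abs_tol = 2 * fp16_epsilon   # ~1.95e-03
rel_tol = 6 * fp16_epsilon   # ~5.86e-03

ref = torch.tensor([0.01, 1.0, 4.0])  # illustrative outputs of torch.exp
mixed_tol = torch.maximum(torch.full_like(ref, abs_tol), ref.abs() * rel_tol)
print(mixed_tol)               # ~[1.95e-03, 5.86e-03, 2.34e-02]
print(mixed_tol.max().item())  # ~2.34e-02, the scalar atol passed to run_method_and_compare_outputs

Collapsing the per-element tolerances to their maximum keeps the single scalar atol that run_method_and_compare_outputs accepts while still scaling with the largest reference output, which for torch.exp corresponds to the largest input.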

backends/xnnpack/test/ops/test_gelu.py

Lines changed: 24 additions & 8 deletions
@@ -1,14 +1,20 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
 import unittest
-
 import torch
 from executorch.backends.xnnpack.test.tester import Tester
 
+def calculate_fp16_gelu_tolerance(ref_output_tensor):
+    fp16_epsilon = 9.77e-4
+    abs_tol = 2 * fp16_epsilon
+    rel_tol = 6 * fp16_epsilon
+
+    ref_abs = ref_output_tensor.abs()
+    mixed_tol = torch.maximum(
+        torch.full_like(ref_abs, abs_tol),
+        ref_abs * rel_tol,
+    )
+
+    final_atol = mixed_tol.max().item()
+    return final_atol, rel_tol
 
 class TestGelu(unittest.TestCase):
     def setUp(self):
@@ -23,6 +29,16 @@ def forward(self, x):
             return self.gelu(x)
 
     def run_gelu_test(self, inputs):
+        input_tensor = inputs[0]
+
+        if input_tensor.dtype == torch.float16:
+            with torch.no_grad():
+                ref_output = torch.nn.functional.gelu(input_tensor.to(torch.float32)).to(torch.float16)
+            atol, rtol = calculate_fp16_gelu_tolerance(ref_output)
+        else:
+            atol = 1e-03
+            rtol = 1e-03
+
         (
             Tester(self.Gelu(), inputs)
             .export()
@@ -32,7 +48,7 @@ def run_gelu_test(self, inputs):
             .check_not(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )
 
     def test_fp16_gelu(self):
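For completeness, a standalone sketch of how the derived tolerances are meant to be consumed. The import path simply mirrors the file location above, and torch.allclose stands in for the Tester's output comparison, which is not shown in this diff; both are assumptions for illustration:

import torch

# Hypothetical standalone check; calculate_fp16_gelu_tolerance is the helper added above,
# and torch.allclose is only a stand-in for the comparison run_method_and_compare_outputs performs.
from executorch.backends.xnnpack.test.ops.test_gelu import calculate_fp16_gelu_tolerance

x = torch.randn(20).to(torch.float16)
ref = torch.nn.functional.gelu(x.to(torch.float32)).to(torch.float16)  # fp32 reference, cast to fp16
out = torch.nn.functional.gelu(x)                                      # computed directly in fp16
atol, rtol = calculate_fp16_gelu_tolerance(ref)
print(torch.allclose(out.to(torch.float32), ref.to(torch.float32), atol=atol, rtol=rtol))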
