Commit 8a3afda

Adjust tolerance for fp16 exp op to handle reasonable calculation discrepancies
1 parent 75d4b2e commit 8a3afda

File tree

2 files changed: +52 −5 lines changed


backends/xnnpack/test/ops/test_exp.py

Lines changed: 28 additions & 4 deletions
@@ -10,6 +10,23 @@
 from executorch.backends.xnnpack.test.tester import Tester
 
 
+def calculate_fp16_exp_tolerance(ref_output_tensor):
+    # Calculate mixed tolerance for float16 used in XNNPACK's float16 policy
+    fp16_epsilon = 9.77e-4
+    abs_tol = 2 * fp16_epsilon
+    rel_tol = 6 * fp16_epsilon
+
+    ref_abs = ref_output_tensor.abs()
+    mixed_tol = torch.maximum(
+        torch.full_like(ref_abs, abs_tol),
+        ref_abs * rel_tol,
+    )
+
+    final_atol = mixed_tol.max().item()
+
+    return final_atol, rel_tol
+
+
 class TestExp(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
@@ -22,6 +39,16 @@ def forward(self, x):
         return torch.exp(x)
 
     def run_exp_test(self, inputs):
+        input_tensor = inputs[0]
+
+        if input_tensor.dtype == torch.float16:
+            with torch.no_grad():
+                ref_output = torch.exp(input_tensor.to(torch.float32)).to(torch.float16)
+            atol, rtol = calculate_fp16_exp_tolerance(ref_output)
+        else:
+            atol = 1e-03
+            rtol = 1e-03
+
         (
             Tester(self.Exp(), inputs)
             .export()
@@ -31,12 +58,9 @@ def run_exp_test(self, inputs):
             .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )
 
-    # TODO (leafs1): Fix flaky tests. Land fix asap
-    # and cherry-pick onto release/0.7 branch
-    @unittest.skip(reason="For float16, numerical discepancies are too high")
     def test_fp16_exp(self):
         inputs = (torch.randn(20).to(torch.float16),)
         self.run_exp_test(inputs)
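
The new tolerance is a mixed bound: an absolute floor of 2 × fp16 epsilon (≈ 9.77e-4, i.e. 2**-10) for outputs near zero, plus a relative term of 6 × epsilon that scales with the magnitude of the reference output; the per-element maximum is then collapsed to a single scalar because run_method_and_compare_outputs takes scalar atol/rtol. Below is a minimal standalone sketch of that policy — the random inputs and the use of eager torch.exp as a stand-in for the delegated XNNPACK output are illustrative assumptions, not part of the commit, which compares against the serialized program via the Tester:

# Illustrative sketch (not part of the commit): how the mixed fp16 tolerance
# behaves for torch.exp. Eager torch.exp stands in for the XNNPACK result.
import torch

fp16_epsilon = 9.77e-4          # fp16 machine epsilon, 2**-10
abs_tol = 2 * fp16_epsilon      # absolute floor for near-zero outputs
rel_tol = 6 * fp16_epsilon      # relative term for larger-magnitude outputs

inputs = torch.randn(20).to(torch.float16)
ref = torch.exp(inputs.to(torch.float32)).to(torch.float16)  # fp32 reference, cast back
out = torch.exp(inputs)                                      # stand-in for the backend output

# Per-element bound: whichever of the absolute floor or the relative term is larger.
mixed_tol = torch.maximum(torch.full_like(ref.abs(), abs_tol), ref.abs() * rel_tol)
err = (out.to(torch.float32) - ref.to(torch.float32)).abs()
print("max error:", err.max().item(), "max allowed:", mixed_tol.max().item())

# The test collapses the per-element bound to scalars for run_method_and_compare_outputs:
atol, rtol = mixed_tol.max().item(), rel_tol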

backends/xnnpack/test/ops/test_gelu.py

Lines changed: 24 additions & 1 deletion
@@ -9,6 +9,19 @@
 import torch
 from executorch.backends.xnnpack.test.tester import Tester
 
+def calculate_fp16_gelu_tolerance(ref_output_tensor):
+    fp16_epsilon = 9.77e-4
+    abs_tol = 2 * fp16_epsilon
+    rel_tol = 6 * fp16_epsilon
+
+    ref_abs = ref_output_tensor.abs()
+    mixed_tol = torch.maximum(
+        torch.full_like(ref_abs, abs_tol),
+        ref_abs * rel_tol,
+    )
+
+    final_atol = mixed_tol.max().item()
+    return final_atol, rel_tol
 
 class TestGelu(unittest.TestCase):
     def setUp(self):
@@ -23,6 +36,16 @@ def forward(self, x):
         return self.gelu(x)
 
     def run_gelu_test(self, inputs):
+        input_tensor = inputs[0]
+
+        if input_tensor.dtype == torch.float16:
+            with torch.no_grad():
+                ref_output = torch.nn.functional.gelu(input_tensor.to(torch.float32)).to(torch.float16)
+            atol, rtol = calculate_fp16_gelu_tolerance(ref_output)
+        else:
+            atol = 1e-03
+            rtol = 1e-03
+
         (
             Tester(self.Gelu(), inputs)
             .export()
@@ -32,7 +55,7 @@ def run_gelu_test(self, inputs):
             .check_not(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
            .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
        )
 
     def test_fp16_gelu(self):
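
calculate_fp16_gelu_tolerance applies the same 2×/6× epsilon policy as the exp helper. If more fp16 operator tests adopt this pattern, one possible follow-up (purely a sketch, not something this commit adds; the helper name and the multiplier parameters are assumptions) would be a shared utility:

# Hypothetical shared helper; name, location, and parameters are assumptions.
import torch

def calculate_fp16_tolerance(ref_output_tensor, abs_mult=2, rel_mult=6):
    # Mixed absolute/relative fp16 tolerance, collapsed to the scalar
    # (atol, rtol) pair expected by run_method_and_compare_outputs.
    fp16_epsilon = 9.77e-4
    abs_tol = abs_mult * fp16_epsilon
    rel_tol = rel_mult * fp16_epsilon

    ref_abs = ref_output_tensor.abs()
    mixed_tol = torch.maximum(torch.full_like(ref_abs, abs_tol), ref_abs * rel_tol)
    return mixed_tol.max().item(), rel_tol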
