Commit 0972df0

Adjust tolerance for fp16 exp op to handle reasonable calculation discrepancies
1 parent 75d4b2e commit 0972df0

File tree

1 file changed: +30 -8 lines changed


backends/xnnpack/test/ops/test_exp.py

Lines changed: 30 additions & 8 deletions
@@ -10,33 +10,55 @@
 from executorch.backends.xnnpack.test.tester import Tester
 
 
+def calculate_fp16_exp_tolerance(ref_output_tensor):
+    # Calculate mixed tolerance for float16 used in XNNPACK's float16 policy
+    fp16_epsilon = 9.77e-4
+    abs_tol = 2 * fp16_epsilon
+    rel_tol = 6 * fp16_epsilon
+
+    ref_abs = ref_output_tensor.abs()
+    mixed_tol = torch.maximum(
+        torch.full_like(ref_abs, abs_tol),
+        ref_abs * rel_tol,
+    )
+
+    final_atol = mixed_tol.max().item()
+
+    return final_atol, rel_tol
+
+
 class TestExp(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
 
     class Exp(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-
         def forward(self, x):
             return torch.exp(x)
 
     def run_exp_test(self, inputs):
+        model = self.Exp()
+        input_tensor = inputs[0]
+
+        if input_tensor.dtype == torch.float16:
+            with torch.no_grad():
+                ref_output = torch.exp(input_tensor.to(torch.float32)).to(torch.float16)
+            atol, rtol = calculate_fp16_exp_tolerance(ref_output)
+        else:
+            atol = 1e-5
+            rtol = 1e-5
+
         (
-            Tester(self.Exp(), inputs)
+            Tester(model, inputs)
             .export()
             .check_count({"torch.ops.aten.exp.default": 1})
             .to_edge_transform_and_lower()
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )
 
-    # TODO (leafs1): Fix flaky tests. Land fix asap
-    # and cherry-pick onto release/0.7 branch
-    @unittest.skip(reason="For float16, numerical discepancies are too high")
     def test_fp16_exp(self):
         inputs = (torch.randn(20).to(torch.float16),)
         self.run_exp_test(inputs)
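
For reference, the sketch below replays the tolerance logic introduced in this commit outside the Tester pipeline, to show roughly what atol/rtol the fp16 path ends up with. It is a minimal standalone copy of calculate_fp16_exp_tolerance plus an illustrative input; the example tensor and the printed values are assumptions for demonstration only, not part of the commit.

import torch

def calculate_fp16_exp_tolerance(ref_output_tensor):
    # fp16 has 10 mantissa bits, so its machine epsilon is 2**-10 ~= 9.77e-4
    fp16_epsilon = 9.77e-4
    abs_tol = 2 * fp16_epsilon  # floor for outputs near zero
    rel_tol = 6 * fp16_epsilon  # grows with the output magnitude

    ref_abs = ref_output_tensor.abs()
    # Per-element tolerance: the larger of the absolute and relative bounds
    mixed_tol = torch.maximum(
        torch.full_like(ref_abs, abs_tol),
        ref_abs * rel_tol,
    )
    # The test passes a single scalar atol, so take the loosest per-element bound
    return mixed_tol.max().item(), rel_tol

# Illustrative input (assumed), mirroring what test_fp16_exp generates
x = torch.randn(20).to(torch.float16)
ref = torch.exp(x.to(torch.float32)).to(torch.float16)
atol, rtol = calculate_fp16_exp_tolerance(ref)
# With randn inputs, max(|exp(x)|) is typically in the single digits, so atol
# usually comes out around a few hundredths while rtol stays fixed at ~5.9e-3.
print(f"atol={atol:.4f}, rtol={rtol:.6f}")

The relative term dominates whenever |exp(x)| exceeds abs_tol / rel_tol = 1/3, which is why deriving a single loose atol from the largest reference value is sufficient for the comparison call in the test.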
