
Commit ea6a363

Adjust tolerance for fp16 exp op to handle reasonable calculation discrepancies
1 parent: 75d4b2e

File tree

1 file changed (+5, -7 lines)

backends/xnnpack/test/ops/test_exp.py

Lines changed: 5 additions & 7 deletions
@@ -21,7 +21,7 @@ def __init__(self):
         def forward(self, x):
             return torch.exp(x)
 
-    def run_exp_test(self, inputs):
+    def run_exp_test(self, inputs, rtol=1e-03, atol=1e-03):
         (
             Tester(self.Exp(), inputs)
             .export()
@@ -31,15 +31,13 @@ def run_exp_test(self, inputs):
             .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(rtol=rtol, atol=atol)
         )
 
-    # TODO (leafs1): Fix flaky tests. Land fix asap
-    # and cherry-pick onto release/0.7 branch
-    @unittest.skip(reason="For float16, numerical discepancies are too high")
     def test_fp16_exp(self):
-        inputs = (torch.randn(20).to(torch.float16),)
-        self.run_exp_test(inputs)
+        for _ in range(1000):
+            inputs = (torch.randn(20).to(torch.float16),)
+            self.run_exp_test(inputs, rtol=2e-03, atol=2e-03)
 
     def test_fp32_exp(self):
         inputs = (torch.randn(20),)
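For context on why the relaxed tolerance is reasonable: float16 has a 10-bit mantissa, so even a correctly rounded exp result carries a relative rounding error on the order of 2^-11 (about 4.9e-04), and a half-precision compute path can drift slightly further from the float32 reference. Below is a minimal standalone sketch, not part of the commit, comparing the old (1e-03) and new (2e-03) tolerances; it assumes a PyTorch build where torch.exp supports float16 on the host device.

import torch

# Hypothetical check (not from the commit): compare a float16 exp result
# against a float32 reference, mirroring the comparison the Tester makes
# between the delegated output and eager mode.
x = torch.randn(20).to(torch.float16)
out_fp16 = torch.exp(x).to(torch.float32)   # half-precision computation
ref_fp32 = torch.exp(x.to(torch.float32))   # full-precision reference

# The old default (1e-03) can occasionally fail for fp16; the relaxed
# values (2e-03) leave headroom for float16 rounding while still being
# tight enough to catch genuine numerical bugs.
print(torch.allclose(out_fp16, ref_fp32, rtol=1e-03, atol=1e-03))
print(torch.allclose(out_fp16, ref_fp32, rtol=2e-03, atol=2e-03))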
