32 changes: 28 additions & 4 deletions backends/xnnpack/test/ops/test_exp.py
@@ -10,6 +10,23 @@
from executorch.backends.xnnpack.test.tester import Tester


def calculate_fp16_exp_tolerance(ref_output_tensor):
    # Compute a mixed absolute/relative tolerance for float16 outputs,
    # following XNNPACK's float16 tolerance policy.
    fp16_epsilon = 9.77e-4  # ~float16 machine epsilon (2**-10)
abs_tol = 2 * fp16_epsilon
rel_tol = 6 * fp16_epsilon

ref_abs = ref_output_tensor.abs()
mixed_tol = torch.maximum(
torch.full_like(ref_abs, abs_tol),
ref_abs * rel_tol,
)

final_atol = mixed_tol.max().item()

return final_atol, rel_tol


class TestExp(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()
@@ -22,6 +39,16 @@ def forward(self, x):
return torch.exp(x)

def run_exp_test(self, inputs):
input_tensor = inputs[0]

if input_tensor.dtype == torch.float16:
with torch.no_grad():
ref_output = torch.exp(input_tensor.to(torch.float32)).to(torch.float16)
atol, rtol = calculate_fp16_exp_tolerance(ref_output)
else:
atol = 1e-03
rtol = 1e-03

(
Tester(self.Exp(), inputs)
.export()
@@ -31,12 +58,9 @@ def run_exp_test(self, inputs):
.check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
.to_executorch()
.serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
)

-    # TODO (leafs1): Fix flaky tests. Land fix asap
-    # and cherry-pick onto release/0.7 branch
-    @unittest.skip(reason="For float16, numerical discepancies are too high")
def test_fp16_exp(self):
inputs = (torch.randn(20).to(torch.float16),)
self.run_exp_test(inputs)
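A minimal sketch (not part of the change) of how the (atol, rtol) pair returned by calculate_fp16_exp_tolerance is meant to be consumed, assuming run_method_and_compare_outputs applies torch.allclose-style semantics, i.e. |actual - ref| <= atol + rtol * |ref|. The helper name fp16_exp_outputs_close and the import path are assumptions, not part of this diff.

import torch

# Import path assumed from the repository layout.
from executorch.backends.xnnpack.test.ops.test_exp import calculate_fp16_exp_tolerance


def fp16_exp_outputs_close(actual: torch.Tensor, input_tensor: torch.Tensor) -> bool:
    # Build the reference the same way run_exp_test does: compute exp in
    # float32, then cast back to float16.
    ref_output = torch.exp(input_tensor.to(torch.float32)).to(torch.float16)
    atol, rtol = calculate_fp16_exp_tolerance(ref_output)
    # Compare in float32 so the tolerances are applied to exact values.
    return torch.allclose(
        actual.to(torch.float32),
        ref_output.to(torch.float32),
        atol=atol,
        rtol=rtol,
    )

Because the scalar atol is the maximum of the per-element mixed tolerance, a single large reference value (exp grows quickly) loosens the comparison for every element; that is the trade-off of collapsing the tensor-valued tolerance into one atol.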
29 changes: 28 additions & 1 deletion backends/xnnpack/test/ops/test_gelu.py
@@ -10,6 +10,21 @@
from executorch.backends.xnnpack.test.tester import Tester


def calculate_fp16_gelu_tolerance(ref_output_tensor):
    # Same mixed absolute/relative tolerance policy as in test_exp.py.
    fp16_epsilon = 9.77e-4  # ~float16 machine epsilon (2**-10)
abs_tol = 2 * fp16_epsilon
rel_tol = 6 * fp16_epsilon

ref_abs = ref_output_tensor.abs()
mixed_tol = torch.maximum(
torch.full_like(ref_abs, abs_tol),
ref_abs * rel_tol,
)

final_atol = mixed_tol.max().item()
return final_atol, rel_tol


class TestGelu(unittest.TestCase):
def setUp(self):
torch._dynamo.reset()
@@ -23,6 +38,18 @@ def forward(self, x):
return self.gelu(x)

def run_gelu_test(self, inputs):
input_tensor = inputs[0]

if input_tensor.dtype == torch.float16:
with torch.no_grad():
ref_output = torch.nn.functional.gelu(
input_tensor.to(torch.float32)
).to(torch.float16)
atol, rtol = calculate_fp16_gelu_tolerance(ref_output)
else:
atol = 1e-03
rtol = 1e-03

(
Tester(self.Gelu(), inputs)
.export()
@@ -32,7 +59,7 @@ def run_gelu_test(self, inputs):
.check_not(["executorch_exir_dialects_edge__ops_aten_gelu_default"])
.to_executorch()
.serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
)

def test_fp16_gelu(self):
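calculate_fp16_gelu_tolerance is identical to calculate_fp16_exp_tolerance apart from its name. A possible consolidation, shown only as a sketch: the shared-helper name, its location, and the multiplier parameters are assumptions, not part of this diff.

import torch


def calculate_fp16_tolerance(ref_output_tensor, abs_mult=2, rel_mult=6):
    # float16 machine epsilon, roughly 2**-10.
    fp16_epsilon = 9.77e-4
    abs_tol = abs_mult * fp16_epsilon
    rel_tol = rel_mult * fp16_epsilon

    ref_abs = ref_output_tensor.abs()
    # Per-element tolerance: an absolute floor of abs_tol, scaled up by
    # rel_tol for larger reference magnitudes.
    mixed_tol = torch.maximum(
        torch.full_like(ref_abs, abs_tol),
        ref_abs * rel_tol,
    )

    # Collapse to a single scalar atol (worst case over the tensor) plus rtol,
    # matching the atol/rtol keywords run_method_and_compare_outputs accepts.
    return mixed_tol.max().item(), rel_tol

Both tests could then call calculate_fp16_tolerance(ref_output) instead of keeping per-op copies of the same logic.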