 from executorch.backends.xnnpack.test.tester import Tester


+def calculate_fp16_exp_tolerance(ref_output_tensor):
+    # Calculate mixed tolerance for float16 used in XNNPACK's float16 policy
+    fp16_epsilon = 9.77e-4
+    abs_tol = 2 * fp16_epsilon
+    rel_tol = 6 * fp16_epsilon
+
+    ref_abs = ref_output_tensor.abs()
+    mixed_tol = torch.maximum(
+        torch.full_like(ref_abs, abs_tol),
+        ref_abs * rel_tol,
+    )
+
+    final_atol = mixed_tol.max().item()
+
+    return final_atol, rel_tol
+
+
 class TestExp(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()

     class Exp(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-
         def forward(self, x):
             return torch.exp(x)

     def run_exp_test(self, inputs):
+        model = self.Exp()
+        input_tensor = inputs[0]
+
+        if input_tensor.dtype == torch.float16:
+            with torch.no_grad():
+                ref_output = torch.exp(input_tensor.to(torch.float32)).to(torch.float16)
+            atol, rtol = calculate_fp16_exp_tolerance(ref_output)
+        else:
+            atol = 1e-5
+            rtol = 1e-5
+
         (
-            Tester(self.Exp(), inputs)
+            Tester(model, inputs)
             .export()
             .check_count({"torch.ops.aten.exp.default": 1})
             .to_edge_transform_and_lower()
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
             .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )

-    # TODO (leafs1): Fix flaky tests. Land fix asap
-    # and cherry-pick onto release/0.7 branch
-    @unittest.skip(reason="For float16, numerical discepancies are too high")
     def test_fp16_exp(self):
         inputs = (torch.randn(20).to(torch.float16),)
         self.run_exp_test(inputs)
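For context, the new helper collapses the element-wise bound max(2*eps, 6*eps*|ref|) into a single scalar atol, since run_method_and_compare_outputs is called with scalar atol/rtol. Below is a minimal standalone sketch of the same tolerance logic outside the Tester flow; the input tensor, the cloned "output", and the allclose comparison are illustrative stand-ins, not part of the change:

import torch

# fp16 machine epsilon (~2**-10); abs_tol guards values near zero,
# rel_tol scales with the magnitude of the reference output.
fp16_epsilon = 9.77e-4
abs_tol = 2 * fp16_epsilon
rel_tol = 6 * fp16_epsilon

# Reference computed in fp32, then cast back to fp16, mirroring run_exp_test.
x = torch.randn(20).to(torch.float16)
ref = torch.exp(x.to(torch.float32)).to(torch.float16)
out = ref.clone()  # hypothetical stand-in for the delegate's output

# Element-wise mixed bound, reduced to one scalar atol for the comparison API.
mixed_tol = torch.maximum(torch.full_like(ref.abs(), abs_tol), ref.abs() * rel_tol)
atol = mixed_tol.max().item()

assert torch.allclose(out, ref, atol=atol, rtol=rel_tol)

Reducing the per-element bound with .max() makes the scalar atol loose enough for the largest-magnitude outputs, which is why the previously skipped fp16 test can be re-enabled without flakiness.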