 from executorch.backends.xnnpack.test.tester import Tester
 
 
+def calculate_fp16_exp_tolerance(ref_output_tensor):
+    # Calculate mixed tolerance for float16 used in XNNPACK's float16 policy
+    fp16_epsilon = 9.77e-4
+    abs_tol = 2 * fp16_epsilon
+    rel_tol = 6 * fp16_epsilon
+
+    ref_abs = ref_output_tensor.abs()
+    mixed_tol = torch.maximum(
+        torch.full_like(ref_abs, abs_tol),
+        ref_abs * rel_tol,
+    )
+
+    final_atol = mixed_tol.max().item()
+
+    return final_atol, rel_tol
+
+
 class TestExp(unittest.TestCase):
     def setUp(self):
         torch._dynamo.reset()
@@ -22,6 +39,16 @@ def forward(self, x):
         return torch.exp(x)
 
     def run_exp_test(self, inputs):
+        input_tensor = inputs[0]
+
+        if input_tensor.dtype == torch.float16:
+            with torch.no_grad():
+                ref_output = torch.exp(input_tensor.to(torch.float32)).to(torch.float16)
+            atol, rtol = calculate_fp16_exp_tolerance(ref_output)
+        else:
+            atol = 1e-03
+            rtol = 1e-03
+
         (
             Tester(self.Exp(), inputs)
             .export()
@@ -31,12 +58,9 @@ def run_exp_test(self, inputs):
             .check_not(["executorch_exir_dialects_edge__ops_aten_exp_default"])
            .to_executorch()
             .serialize()
-            .run_method_and_compare_outputs()
+            .run_method_and_compare_outputs(atol=atol, rtol=rtol)
         )
 
-    # TODO (leafs1): Fix flaky tests. Land fix asap
-    # and cherry-pick onto release/0.7 branch
-    @unittest.skip(reason="For float16, numerical discepancies are too high")
     def test_fp16_exp(self):
         inputs = (torch.randn(20).to(torch.float16),)
         self.run_exp_test(inputs)
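
The new helper implements the mixed float16 tolerance named in its comment: an absolute floor of 2*eps and a relative bound of 6*eps (eps = 9.77e-4, roughly float16 machine epsilon), combined per element by taking whichever is larger, then collapsed to a single atol via the maximum over the reference tensor. Below is a minimal standalone sketch of that behaviour against a float16 exp reference; it uses only stock PyTorch, and the perturbed fake_delegate_out tensor is purely illustrative, not part of this change.

import torch

# Constants copied from the diff above; eps is approximately float16 machine epsilon.
fp16_epsilon = 9.77e-4
abs_tol = 2 * fp16_epsilon   # floor for reference values near zero
rel_tol = 6 * fp16_epsilon   # scales with the magnitude of exp(x)

x = torch.randn(20).to(torch.float16)
ref = torch.exp(x.to(torch.float32)).to(torch.float16)

# Per-element tolerance: the larger of the absolute and relative bounds,
# reduced to a single atol by taking the max over the tensor.
mixed = torch.maximum(torch.full_like(ref.abs(), abs_tol), ref.abs() * rel_tol)
atol = mixed.max().item()

# A hypothetical output that stays inside this envelope passes the comparison.
fake_delegate_out = ref + ((torch.rand(20) - 0.5) * abs_tol).to(torch.float16)
assert torch.allclose(fake_delegate_out, ref, rtol=rel_tol, atol=atol)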