1010from  executorch .backends .xnnpack .test .tester  import  Tester 
1111
1212
13+ def  calculate_fp16_exp_tolerance (ref_output_tensor ):
14+     # Calculate mixed tolerance for float16 used in XNNPACK's float16 policy 
15+     fp16_epsilon  =  9.77e-4 
16+     abs_tol  =  2  *  fp16_epsilon 
17+     rel_tol  =  6  *  fp16_epsilon 
18+ 
19+     ref_abs  =  ref_output_tensor .abs ()
20+     mixed_tol  =  torch .maximum (
21+         torch .full_like (ref_abs , abs_tol ),
22+         ref_abs  *  rel_tol ,
23+     )
24+ 
25+     final_atol  =  mixed_tol .max ().item ()
26+ 
27+     return  final_atol , rel_tol 
28+ 
29+ 
1330class  TestExp (unittest .TestCase ):
1431    def  setUp (self ):
1532        torch ._dynamo .reset ()
@@ -22,6 +39,16 @@ def forward(self, x):
2239            return  torch .exp (x )
2340
2441    def  run_exp_test (self , inputs ):
42+         input_tensor  =  inputs [0 ]
43+ 
44+         if  input_tensor .dtype  ==  torch .float16 :
45+             with  torch .no_grad ():
46+                 ref_output  =  torch .exp (input_tensor .to (torch .float32 )).to (torch .float16 )
47+             atol , rtol  =  calculate_fp16_exp_tolerance (ref_output )
48+         else :
49+             atol  =  1e-03 
50+             rtol  =  1e-03 
51+ 
2552        (
2653            Tester (self .Exp (), inputs )
2754            .export ()
@@ -31,12 +58,9 @@ def run_exp_test(self, inputs):
3158            .check_not (["executorch_exir_dialects_edge__ops_aten_exp_default" ])
3259            .to_executorch ()
3360            .serialize ()
34-             .run_method_and_compare_outputs ()
61+             .run_method_and_compare_outputs (atol = atol ,  rtol = rtol )
3562        )
3663
37-     # TODO (leafs1): Fix flaky tests. Land fix asap 
38-     # and cherry-pick onto release/0.7 branch 
39-     @unittest .skip (reason = "For float16, numerical discepancies are too high" ) 
4064    def  test_fp16_exp (self ):
4165        inputs  =  (torch .randn (20 ).to (torch .float16 ),)
4266        self .run_exp_test (inputs )
0 commit comments