 import os
 import tempfile
 import zipfile
-from typing import Any, Optional, Tuple
+from collections import defaultdict
+from typing import Optional, Tuple
 
 import torch
 
 
+def flatten_args(args) -> tuple | list:
+    flattened_args: list = []
+    if isinstance(args, torch.Tensor):
+        return [args]
+
+    for arg in args:
+        if isinstance(arg, (tuple, list)):
+            flattened_args.extend(arg)
+        else:
+            flattened_args.append(arg)
+
+    return tuple(flattened_args)
+
+
 class GenericModelEvaluator:
     def __init__(
         self,
@@ -32,31 +47,34 @@ def __init__(
         else:
            self.tosa_output_path = None
 
-    def get_model_error(self) -> tuple[float, float, float, float]:
+    def get_model_error(self) -> defaultdict:
         """
-        Returns the following metrics between the outputs of the FP32 and INT8 model:
+        Returns a dict containing the following metrics between the outputs of the FP32 and INT8 model:
         - Maximum error
         - Maximum absolute error
         - Maximum percentage error
         - Mean absolute error
         """
-        fp32_output = self.fp32_model(*self.example_input)
-        int8_output = self.int8_model(*self.example_input)
-
-        difference = fp32_output - int8_output
-        percentage_error = torch.div(difference, fp32_output) * 100
-
-        max_error = torch.max(difference).item()
-        max_absolute_error = torch.max(torch.abs(difference)).item()
-        max_percentage_error = torch.max(percentage_error).item()
-        mean_absolute_error = torch.mean(torch.abs(difference).float()).item()
-
-        return (
-            float(max_error),
-            float(max_absolute_error),
-            float(max_percentage_error),
-            float(mean_absolute_error),
-        )
+        fp32_outputs = flatten_args(self.fp32_model(*self.example_input))
+        int8_outputs = flatten_args(self.int8_model(*self.example_input))
+
+        model_error_dict = defaultdict(list)
+
+        for fp32_output, int8_output in zip(fp32_outputs, int8_outputs):
+            difference = fp32_output - int8_output
+            percentage_error = torch.div(difference, fp32_output) * 100
+            model_error_dict["max_error"].append(torch.max(difference).item())
+            model_error_dict["max_absolute_error"].append(
+                torch.max(torch.abs(difference)).item()
+            )
+            model_error_dict["max_percentage_error"].append(
+                torch.max(percentage_error).item()
+            )
+            model_error_dict["mean_absolute_error"].append(
+                torch.mean(torch.abs(difference).float()).item()
+            )
+
+        return model_error_dict
 
     def get_compression_ratio(self) -> float:
         """Compute the compression ratio of the outputted TOSA flatbuffer."""
@@ -72,19 +90,10 @@ def get_compression_ratio(self) -> float:
 
         return compression_ratio
 
-    def evaluate(self) -> dict[str, Any]:
-        max_error, max_absolute_error, max_percent_error, mean_absolute_error = (
-            self.get_model_error()
-        )
-        output_metrics = {
-            "name": self.model_name,
-            "metrics": {
-                "max_error": max_error,
-                "max_absolute_error": max_absolute_error,
-                "max_percentage_error": max_percent_error,
-                "mean_absolute_error": mean_absolute_error,
-            },
-        }
+    def evaluate(self) -> dict:
+        model_error_dict = self.get_model_error()
+
+        output_metrics = {"name": self.model_name, "metrics": dict(model_error_dict)}
 
         if self.tosa_output_path:
             # We know output_metrics["metrics"] is list since we just defined it, safe to ignore.
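
A minimal usage sketch of how the reworked metrics might be consumed. Here `evaluator` is a hypothetical, already-constructed GenericModelEvaluator instance (its constructor arguments are not shown in this diff), and the loop below is only an illustration of the new return shape, not part of the change itself:

# Hypothetical usage: `evaluator` is an already-constructed GenericModelEvaluator.
# get_model_error() now returns a defaultdict of lists, with one entry per
# flattened model output, instead of the previous four scalar values.
model_error_dict = evaluator.get_model_error()
for metric_name, per_output_values in model_error_dict.items():
    worst = max(per_output_values, key=abs)
    print(f"{metric_name}: {per_output_values} (worst output: {worst:.6f})")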