44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import logging
78import os
9+ import random
810import tempfile
911import zipfile
12+
1013from collections import defaultdict
11- from typing import Optional , Tuple
14+ from pathlib import Path
15+ from typing import Any , Optional , Tuple
1216
1317import torch
18+ from torch .nn .modules import Module
19+ from torch .utils .data import DataLoader
20+ from torchvision import datasets , transforms
21+
22+
23+ # Logger for outputting progress for longer running evaluation
24+ logger = logging .getLogger (__name__ )
25+ logger .setLevel (logging .INFO )
1426
1527
1628def flatten_args (args ) -> tuple | list :
@@ -28,6 +40,8 @@ def flatten_args(args) -> tuple | list:
2840
2941
3042class GenericModelEvaluator :
43+ REQUIRES_CONFIG = False
44+
3145 def __init__ (
3246 self ,
3347 model_name : str ,
@@ -90,7 +104,7 @@ def get_compression_ratio(self) -> float:
90104
91105 return compression_ratio
92106
93- def evaluate (self ) -> dict [any ]:
107+ def evaluate (self ) -> dict [Any ]:
94108 model_error_dict = self .get_model_error ()
95109
96110 output_metrics = {"name" : self .model_name , "metrics" : dict (model_error_dict )}
@@ -103,3 +117,93 @@ def evaluate(self) -> dict[any]:
103117 ] = self .get_compression_ratio ()
104118
105119 return output_metrics
120+
121+
class MobileNetV2Evaluator(GenericModelEvaluator):
    """Evaluator for MobileNetV2 adding top-1/top-5 classification accuracy
    on the int8 model to the generic error/compression metrics.

    Needs external configuration (batch size, validation dataset path),
    hence REQUIRES_CONFIG is True.
    """

    REQUIRES_CONFIG = True

    # Number of images randomly sampled from the training set for calibration.
    CALIBRATION_SAMPLES = 1000

    def __init__(
        self,
        model_name: str,
        fp32_model: Module,
        int8_model: Module,
        example_input: Tuple[torch.Tensor],
        tosa_output_path: str | None,
        batch_size: int,
        validation_dataset_path: str,
    ) -> None:
        super().__init__(
            model_name, fp32_model, int8_model, example_input, tosa_output_path
        )

        self.__batch_size = batch_size
        self.__validation_set_path = validation_dataset_path

    @staticmethod
    def __load_dataset(directory: str) -> datasets.ImageFolder:
        """Load an ImageFolder dataset with the resize/crop/normalize
        preprocessing shared by calibration and validation.

        Raises:
            FileNotFoundError: if ``directory`` does not exist.
        """
        directory_path = Path(directory)
        if not directory_path.exists():
            raise FileNotFoundError(f"Directory: {directory} does not exist.")

        # NOTE(review): these normalization constants differ slightly from the
        # canonical ImageNet values (0.485/0.456/0.406, 0.229/0.224/0.225).
        # Kept as-is to preserve behavior — confirm they are intentional.
        transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220]
                ),
            ]
        )
        return datasets.ImageFolder(directory_path, transform=transform)

    @staticmethod
    def get_calibrator(training_dataset_path: str) -> DataLoader:
        """Return a DataLoader over a random subset of the training set,
        used to calibrate the quantized model.
        """
        dataset = MobileNetV2Evaluator.__load_dataset(training_dataset_path)
        # Clamp the sample count so a dataset with fewer than
        # CALIBRATION_SAMPLES images does not make random.sample raise
        # ValueError ("sample larger than population").
        sample_count = min(MobileNetV2Evaluator.CALIBRATION_SAMPLES, len(dataset))
        rand_indices = random.sample(range(len(dataset)), k=sample_count)

        # Return a subset of the dataset to be used for calibration
        return torch.utils.data.DataLoader(
            torch.utils.data.Subset(dataset, rand_indices),
            batch_size=1,
            shuffle=False,
        )

    def __evaluate_mobilenet(self) -> Tuple[float, float]:
        """Run the int8 model over the validation set.

        Returns:
            (top-1 accuracy, top-5 accuracy) as fractions in [0, 1].
        """
        dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path)
        loaded_dataset = DataLoader(
            dataset,
            batch_size=self.__batch_size,
            shuffle=False,
        )

        top1_correct = 0
        top5_correct = 0
        samples_seen = 0

        # Inference only: disable autograd to avoid tracking gradients and
        # holding intermediate activations for the whole validation pass.
        with torch.no_grad():
            for image, target in loaded_dataset:
                prediction = self.int8_model(image)
                top1_prediction = torch.topk(prediction, k=1, dim=1).indices
                top5_prediction = torch.topk(prediction, k=5, dim=1).indices

                # (batch, k) == (batch, 1) broadcasts; a row contributes 1
                # when the target appears among the top-k indices.
                top1_correct += (top1_prediction == target.view(-1, 1)).sum().item()
                top5_correct += (top5_prediction == target.view(-1, 1)).sum().item()

                # Use the actual number of samples processed so the running
                # accuracy stays correct on a final partial batch (the old
                # (i + 1) * batch_size divisor over-counted there).
                samples_seen += target.size(0)

                logger.info("Iteration: %s", samples_seen)
                logger.info("Top 1: %s", top1_correct / samples_seen)
                logger.info("Top 5: %s", top5_correct / samples_seen)

        top1_accuracy = top1_correct / len(dataset)
        top5_accuracy = top5_correct / len(dataset)

        return top1_accuracy, top5_accuracy

    def evaluate(self) -> dict[str, Any]:
        """Extend the generic metrics dict with top-1/top-5 accuracy."""
        top1_accuracy, top5_accuracy = self.__evaluate_mobilenet()
        output = super().evaluate()

        output["metrics"]["accuracy"] = {
            "top-1": top1_accuracy,
            "top-5": top5_accuracy,
        }
        return output
0 commit comments