# Copyright 2024-2025 Arm Limited and/or its affiliates.
- # All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
logger.setLevel(logging.INFO)


+ # ImageNet 224x224 transforms (Resize -> CenterCrop -> ToTensor -> Normalize).
+ # If future models require different preprocessing, extend this helper accordingly.
+ def _get_imagenet_224_transforms():
+     """Return standard ImageNet 224x224 preprocessing transforms."""
+     return transforms.Compose(
+         [
+             transforms.Resize(256),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ]
+     )
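+
+ # Example (hypothetical image file): the helper above maps any RGB image to a
+ # normalized (3, 224, 224) tensor, e.g.:
+ #   from PIL import Image
+ #   tfm = _get_imagenet_224_transforms()
+ #   x = tfm(Image.open("dog.jpg").convert("RGB")).unsqueeze(0)  # (1, 3, 224, 224)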
+
+
+ def _build_calibration_loader(
+     dataset: datasets.ImageFolder, max_items: int
+ ) -> DataLoader:
+     """Return a DataLoader over a deterministic, shuffled subset of size <= max_items.
+
+     Shuffles with seed ARM_EVAL_CALIB_SEED (an int) if set, else the default 1337;
+     then selects the first k indices and sorts them, so enumeration order stays
+     stable while the chosen content depends on the seed.
+     """
+     k = min(max_items, len(dataset))
+     seed_env = os.getenv("ARM_EVAL_CALIB_SEED")
+     default_seed = 1337
+     if seed_env is not None:
+         try:
+             seed = int(seed_env)
+         except ValueError:
+             logger.warning(
+                 "ARM_EVAL_CALIB_SEED is not an int (%s); using default seed %d",
+                 seed_env,
+                 default_seed,
+             )
+             seed = default_seed
+     else:
+         seed = default_seed
+     rng = random.Random(seed)
+     indices = list(range(len(dataset)))
+     rng.shuffle(indices)
+     selected = sorted(indices[:k])
+     return torch.utils.data.DataLoader(
+         torch.utils.data.Subset(dataset, selected), batch_size=1, shuffle=False
+     )
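+
+ # Sketch of intended use (hypothetical dataset path): the subset is reproducible
+ # for a fixed seed; export ARM_EVAL_CALIB_SEED to vary which images are sampled.
+ #   dataset = datasets.ImageFolder("/data/imagenet/train", _get_imagenet_224_transforms())
+ #   loader = _build_calibration_loader(dataset, max_items=1000)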
+
+
+ def _load_imagenet_folder(directory: str) -> datasets.ImageFolder:
+     """Shared helper to load an ImageNet-layout folder.
+
+     Raises FileNotFoundError early for a missing directory, to aid debugging.
+     """
+     directory_path = Path(directory)
+     if not directory_path.exists():
+         raise FileNotFoundError(f"Directory: {directory} does not exist.")
+     transform = _get_imagenet_224_transforms()
+     return datasets.ImageFolder(directory_path, transform=transform)
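+
+ # Note: torchvision's ImageFolder expects the standard layout
+ #   <directory>/<class_name>/<image file>
+ # and derives labels from the sorted class folder names.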
+
+
class GenericModelEvaluator:
+     """Base evaluator computing quantization error metrics and an optional compression ratio.
+
+     Subclasses can extend this: provide calibration data (get_calibrator) and override
+     evaluate() to add domain-specific metrics (e.g. top-1 / top-5 accuracy).
+     """
+
+     @staticmethod
+     def evaluate_topk(
+         model: Module,
+         dataset: datasets.ImageFolder,
+         batch_size: int,
+         topk: int = 5,
+         log_every: int = 50,
+     ) -> Tuple[float, float]:
+         """Evaluate model top-1 / top-k accuracy.
+
+         Args:
+             model: Torch module (should be in eval() mode prior to the call).
+             dataset: ImageFolder-style dataset.
+             batch_size: Batch size for evaluation.
+             topk: Maximum k for accuracy (default 5).
+             log_every: Log running accuracy every N batches.
+         Returns:
+             (top1_accuracy, topk_accuracy)
+         """
+         # Some exported / quantized models (torchao PT2E) disallow direct eval()/train().
+         # Try to switch to eval mode, but degrade gracefully if unsupported.
+         try:
+             model.eval()
+         except NotImplementedError:
+             # Attempt to enable train/eval overrides if the torchao helper is present.
+             try:
+                 from torchao.quantization.pt2e.utils import (  # type: ignore
+                     allow_exported_model_train_eval,
+                 )
+
+                 allow_exported_model_train_eval(model)
+                 try:
+                     model.eval()
+                 except Exception:
+                     logger.debug(
+                         "Model eval() still not supported after allow_exported_model_train_eval; proceeding without explicit eval()."
+                     )
+             except Exception:
+                 logger.debug(
+                     "Model eval() unsupported and torchao allow_exported_model_train_eval not available; proceeding."
+                 )
+         loaded_dataset = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+         top1_correct = 0
+         topk_correct = 0
+         total = 0
+         with torch.inference_mode():  # disables autograd and enables some backend optimizations
+             for i, (image, target) in enumerate(loaded_dataset):
+                 prediction = model(image)
+                 topk_indices = torch.topk(prediction, k=topk, dim=1).indices
+                 # Reshape target to (batch, 1) for broadcasting against topk_indices.
+                 target_view = target.view(-1, 1)
+                 top1_correct += (topk_indices[:, :1] == target_view).sum().item()
+                 topk_correct += (topk_indices == target_view).sum().item()
+                 batch_sz = image.size(0)
+                 total += batch_sz
+                 if (i + 1) % log_every == 0 or total == len(dataset):
+                     logger.info(
+                         "Eval progress: %d / %d top1=%.4f top%d=%.4f",
+                         total,
+                         len(dataset),
+                         top1_correct / total,
+                         topk,
+                         topk_correct / total,
+                     )
+         top1_accuracy = top1_correct / len(dataset)
+         topk_accuracy = topk_correct / len(dataset)
+         return top1_accuracy, topk_accuracy
+
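+     # Worked sketch of the top-k bookkeeping above (hypothetical values): for
+     # predictions of shape (2, 4), targets [2, 0] and topk=2 we might get
+     #   topk_indices = [[2, 1], [3, 0]]
+     #   topk_indices == target.view(-1, 1)  ->  [[True, False], [False, True]]
+     # Top-1 counts matches in column 0 only; top-k counts a match anywhere in a row.
+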
    REQUIRES_CONFIG = False

    def __init__(
@@ -53,12 +184,13 @@ def __init__(
        self.tosa_output_path = ""

    def get_model_error(self) -> defaultdict:
-         """
-         Returns a dict containing the following metrics between the outputs of the FP32 and INT8 model:
-         - Maximum error
-         - Maximum absolute error
-         - Maximum percentage error
-         - Mean absolute error
+         """Return per-output quantization error statistics.
+
+         Metrics (lists per output tensor):
+             max_error
+             max_absolute_error
+             max_percentage_error (safe-divided; zero fp32 elements -> 0%)
+             mean_absolute_error
        """
        fp32_outputs, _ = tree_flatten(self.fp32_model(*self.example_input))
        int8_outputs, _ = tree_flatten(self.int8_model(*self.example_input))
@@ -67,7 +199,12 @@ def get_model_error(self) -> defaultdict:

        for fp32_output, int8_output in zip(fp32_outputs, int8_outputs):
            difference = fp32_output - int8_output
-             percentage_error = torch.div(difference, fp32_output) * 100
+             # Avoid division by zero: elements where fp32 == 0 contribute 0%.
+             percentage_error = torch.where(
+                 fp32_output != 0,
+                 difference / fp32_output * 100,
+                 torch.zeros_like(difference),
+             )
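+             # E.g. (hypothetical values) fp32_output = [0.0, 2.0], int8_output = [0.1, 1.0]
+             # gives difference = [-0.1, 1.0] and percentage_error = [0.0, 50.0]:
+             # the zero fp32 element contributes 0% instead of dividing by zero.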
            model_error_dict["max_error"].append(torch.max(difference).item())
            model_error_dict["max_absolute_error"].append(
                torch.max(torch.abs(difference)).item()
@@ -132,77 +269,116 @@ def __init__(

    @staticmethod
    def __load_dataset(directory: str) -> datasets.ImageFolder:
-         directory_path = Path(directory)
-         if not directory_path.exists():
-             raise FileNotFoundError(f"Directory: {directory} does not exist.")
-
-         transform = transforms.Compose(
-             [
-                 transforms.Resize(256),
-                 transforms.CenterCrop(224),
-                 transforms.ToTensor(),
-                 transforms.Normalize(
-                     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
-                 ),
-             ]
-         )
-         return datasets.ImageFolder(directory_path, transform=transform)
+         return _load_imagenet_folder(directory)

    @staticmethod
    def get_calibrator(training_dataset_path: str) -> DataLoader:
        dataset = MobileNetV2Evaluator.__load_dataset(training_dataset_path)
-         rand_indices = random.sample(range(len(dataset)), k=1000)
+         return _build_calibration_loader(dataset, 1000)

-         # Return a subset of the dataset to be used for calibration
-         return torch.utils.data.DataLoader(
-             torch.utils.data.Subset(dataset, rand_indices),
-             batch_size=1,
-             shuffle=False,
+     @classmethod
+     def from_config(
+         cls,
+         model_name: str,
+         fp32_model: Module,
+         int8_model: Module,
+         example_input: Tuple[torch.Tensor],
+         tosa_output_path: str | None,
+         config: dict[str, Any],
+     ) -> "MobileNetV2Evaluator":
+         """Factory constructing an evaluator from a config dict.
+
+         Expected keys: batch_size, validation_dataset_path
+         """
+         return cls(
+             model_name,
+             fp32_model,
+             int8_model,
+             example_input,
+             tosa_output_path,
+             batch_size=config["batch_size"],
+             validation_dataset_path=config["validation_dataset_path"],
        )

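+     # Example config consumed by from_config (hypothetical path):
+     #   {"batch_size": 64, "validation_dataset_path": "/data/imagenet/val"}
+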
-     def __evaluate_mobilenet(self) -> Tuple[float, float]:
+     def evaluate(self) -> dict[str, Any]:
+         # Load the dataset and compute top-1 / top-5 accuracy.
        dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path)
-         loaded_dataset = DataLoader(
-             dataset,
-             batch_size=self.__batch_size,
-             shuffle=False,
+         top1, top5 = GenericModelEvaluator.evaluate_topk(
+             self.int8_model, dataset, self.__batch_size, topk=5
        )
+         output = super().evaluate()

-         top1_correct = 0
-         top5_correct = 0
+         output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5}
+         return output

-         for i, (image, target) in enumerate(loaded_dataset):
-             prediction = self.int8_model(image)
-             top1_prediction = torch.topk(prediction, k=1, dim=1).indices
-             top5_prediction = torch.topk(prediction, k=5, dim=1).indices

-             top1_correct += (top1_prediction == target.view(-1, 1)).sum().item()
-             top5_correct += (top5_prediction == target.view(-1, 1)).sum().item()
+ class DeiTTinyEvaluator(GenericModelEvaluator):
+     REQUIRES_CONFIG = True

-             logger.info("Iteration: {}".format((i + 1) * self.__batch_size))
-             logger.info(
-                 "Top 1: {}".format(top1_correct / ((i + 1) * self.__batch_size))
-             )
-             logger.info(
-                 "Top 5: {}".format(top5_correct / ((i + 1) * self.__batch_size))
-             )
+     def __init__(
+         self,
+         model_name: str,
+         fp32_model: Module,
+         int8_model: Module,
+         example_input: Tuple[torch.Tensor],
+         tosa_output_path: str | None,
+         batch_size: int,
+         validation_dataset_path: str,
+     ) -> None:
+         super().__init__(
+             model_name, fp32_model, int8_model, example_input, tosa_output_path
+         )
+         self.__batch_size = batch_size
+         self.__validation_set_path = validation_dataset_path

-         top1_accuracy = top1_correct / len(dataset)
-         top5_accuracy = top5_correct / len(dataset)
+     @staticmethod
+     def __load_dataset(directory: str) -> datasets.ImageFolder:
+         return _load_imagenet_folder(directory)

-         return top1_accuracy, top5_accuracy
+     @staticmethod
+     def get_calibrator(training_dataset_path: str) -> DataLoader:
+         dataset = DeiTTinyEvaluator.__load_dataset(training_dataset_path)
+         return _build_calibration_loader(dataset, 1000)
+
+     @classmethod
+     def from_config(
+         cls,
+         model_name: str,
+         fp32_model: Module,
+         int8_model: Module,
+         example_input: Tuple[torch.Tensor],
+         tosa_output_path: str | None,
+         config: dict[str, Any],
+     ) -> "DeiTTinyEvaluator":
+         """Factory constructing an evaluator from a config dict.
+
+         Expected keys: batch_size, validation_dataset_path
+         """
+         return cls(
+             model_name,
+             fp32_model,
+             int8_model,
+             example_input,
+             tosa_output_path,
+             batch_size=config["batch_size"],
+             validation_dataset_path=config["validation_dataset_path"],
+         )

    def evaluate(self) -> dict[str, Any]:
-         top1_correct, top5_correct = self.__evaluate_mobilenet()
+         # Load the dataset and compute top-1 / top-5 accuracy.
+         dataset = DeiTTinyEvaluator.__load_dataset(self.__validation_set_path)
+         top1, top5 = GenericModelEvaluator.evaluate_topk(
+             self.int8_model, dataset, self.__batch_size, topk=5
+         )
        output = super().evaluate()
-
-         output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct}
+         output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5}
        return output


evaluators: dict[str, type[GenericModelEvaluator]] = {
    "generic": GenericModelEvaluator,
    "mv2": MobileNetV2Evaluator,
+     "deit_tiny": DeiTTinyEvaluator,
}


@@ -223,6 +399,10 @@ def evaluator_calibration_data(
        return evaluator.get_calibrator(
            training_dataset_path=config["training_dataset_path"]
        )
+     if evaluator is DeiTTinyEvaluator:
+         return evaluator.get_calibrator(
+             training_dataset_path=config["training_dataset_path"]
+         )
    else:
        raise RuntimeError(f"Unknown evaluator: {evaluator_name}")

@@ -238,30 +418,30 @@ def evaluate_model(
) -> None:
    evaluator = evaluators[evaluator_name]

-     # Get the path of the TOSA flatbuffer that is dumped
    intermediates_path = Path(intermediates)
    tosa_paths = list(intermediates_path.glob("*.tosa"))

    if evaluator.REQUIRES_CONFIG:
        assert evaluator_config is not None
-
        config_path = Path(evaluator_config)
        with config_path.open() as f:
            config = json.load(f)

-         if evaluator == MobileNetV2Evaluator:
-             mv2_evaluator = cast(type[MobileNetV2Evaluator], evaluator)
-             init_evaluator: GenericModelEvaluator = mv2_evaluator(
+         # Prefer a subclass-provided from_config if available.
+         if hasattr(evaluator, "from_config"):
+             factory = cast(Any, evaluator.from_config)  # type: ignore[attr-defined]
+             init_evaluator = factory(
                model_name,
                model_fp32,
                model_int8,
                example_inputs,
                str(tosa_paths[0]),
-                 batch_size=config["batch_size"],
-                 validation_dataset_path=config["validation_dataset_path"],
+                 config,
            )
        else:
-             raise RuntimeError(f"Unknown evaluator {evaluator_name}")
+             raise RuntimeError(
+                 f"Evaluator {evaluator_name} requires config but does not implement from_config()"
+             )
    else:
        init_evaluator = evaluator(
            model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])