@@ -167,14 +167,14 @@ def __init__(
167167 self ,
168168 model_name : str ,
169169 fp32_model : torch .nn .Module ,
170- int8_model : torch .nn .Module ,
170+ quant_model : torch .nn .Module ,
171171 example_input : Tuple [torch .Tensor ],
172172 tosa_output_path : Optional [str ],
173173 ) -> None :
174174 self .model_name = model_name
175175
176176 self .fp32_model = fp32_model
177- self .int8_model = int8_model
177+ self .quant_model = quant_model
178178 self .example_input = example_input
179179
180180 if tosa_output_path :
@@ -192,12 +192,12 @@ def get_model_error(self) -> defaultdict:
192192 mean_absolute_error
193193 """
194194 fp32_outputs , _ = tree_flatten (self .fp32_model (* self .example_input ))
195- int8_outputs , _ = tree_flatten (self .int8_model (* self .example_input ))
195+ quant_outputs , _ = tree_flatten (self .quant_model (* self .example_input ))
196196
197197 model_error_dict = defaultdict (list )
198198
199- for fp32_output , int8_output in zip (fp32_outputs , int8_outputs ):
200- difference = fp32_output - int8_output
199+ for fp32_output , quant_output in zip (fp32_outputs , quant_outputs ):
200+ difference = fp32_output - quant_output
201201 # Avoid divide by zero: elements where fp32 == 0 produce 0% contribution
202202 percentage_error = torch .where (
203203 fp32_output != 0 ,
@@ -252,14 +252,14 @@ def __init__(
252252 self ,
253253 model_name : str ,
254254 fp32_model : Module ,
255- int8_model : Module ,
255+ quant_model : Module ,
256256 example_input : Tuple [torch .Tensor ],
257257 tosa_output_path : str | None ,
258258 batch_size : int ,
259259 validation_dataset_path : str ,
260260 ) -> None :
261261 super ().__init__ (
262- model_name , fp32_model , int8_model , example_input , tosa_output_path
262+ model_name , fp32_model , quant_model , example_input , tosa_output_path
263263 )
264264
265265 self .__batch_size = batch_size
@@ -279,7 +279,7 @@ def from_config(
279279 cls ,
280280 model_name : str ,
281281 fp32_model : Module ,
282- int8_model : Module ,
282+ quant_model : Module ,
283283 example_input : Tuple [torch .Tensor ],
284284 tosa_output_path : str | None ,
285285 config : dict [str , Any ],
@@ -291,7 +291,7 @@ def from_config(
291291 return cls (
292292 model_name ,
293293 fp32_model ,
294- int8_model ,
294+ quant_model ,
295295 example_input ,
296296 tosa_output_path ,
297297 batch_size = config ["batch_size" ],
@@ -302,10 +302,9 @@ def evaluate(self) -> dict[str, Any]:
302302 # Load dataset and compute top-1 / top-5
303303 dataset = MobileNetV2Evaluator .__load_dataset (self .__validation_set_path )
304304 top1_correct , top5_correct = GenericModelEvaluator .evaluate_topk (
305- self .int8_model , dataset , self .__batch_size , topk = 5
305+ self .quant_model , dataset , self .__batch_size , topk = 5
306306 )
307307 output = super ().evaluate ()
308-
309308 output ["metrics" ]["accuracy" ] = {"top-1" : top1_correct , "top-5" : top5_correct }
310309 return output
311310
@@ -317,14 +316,14 @@ def __init__(
317316 self ,
318317 model_name : str ,
319318 fp32_model : Module ,
320- int8_model : Module ,
319+ quant_model : Module ,
321320 example_input : Tuple [torch .Tensor ],
322321 tosa_output_path : str | None ,
323322 batch_size : int ,
324323 validation_dataset_path : str ,
325324 ) -> None :
326325 super ().__init__ (
327- model_name , fp32_model , int8_model , example_input , tosa_output_path
326+ model_name , fp32_model , quant_model , example_input , tosa_output_path
328327 )
329328 self .__batch_size = batch_size
330329 self .__validation_set_path = validation_dataset_path
@@ -343,7 +342,7 @@ def from_config(
343342 cls ,
344343 model_name : str ,
345344 fp32_model : Module ,
346- int8_model : Module ,
345+ quant_model : Module ,
347346 example_input : Tuple [torch .Tensor ],
348347 tosa_output_path : str | None ,
349348 config : dict [str , Any ],
@@ -355,7 +354,7 @@ def from_config(
355354 return cls (
356355 model_name ,
357356 fp32_model ,
358- int8_model ,
357+ quant_model ,
359358 example_input ,
360359 tosa_output_path ,
361360 batch_size = config ["batch_size" ],
@@ -366,7 +365,7 @@ def evaluate(self) -> dict[str, Any]:
366365 # Load dataset and compute top-1 / top-5
367366 dataset = DeiTTinyEvaluator .__load_dataset (self .__validation_set_path )
368367 top1 , top5 = GenericModelEvaluator .evaluate_topk (
369- self .int8_model , dataset , self .__batch_size , topk = 5
368+ self .quant_model , dataset , self .__batch_size , topk = 5
370369 )
371370 output = super ().evaluate ()
372371 output ["metrics" ]["accuracy" ] = {"top-1" : top1 , "top-5" : top5 }
@@ -380,14 +379,14 @@ def __init__(
380379 self ,
381380 model_name : str ,
382381 fp32_model : Module ,
383- int8_model : Module ,
382+ quant_model : Module ,
384383 example_input : Tuple [torch .Tensor ],
385384 tosa_output_path : str | None ,
386385 batch_size : int ,
387386 validation_dataset_path : str ,
388387 ) -> None :
389388 super ().__init__ (
390- model_name , fp32_model , int8_model , example_input , tosa_output_path
389+ model_name , fp32_model , quant_model , example_input , tosa_output_path
391390 )
392391 self .__batch_size = batch_size
393392 self .__validation_set_path = validation_dataset_path
@@ -406,15 +405,15 @@ def from_config(
406405 cls ,
407406 model_name : str ,
408407 fp32_model : Module ,
409- int8_model : Module ,
408+ quant_model : Module ,
410409 example_input : Tuple [torch .Tensor ],
411410 tosa_output_path : str | None ,
412411 config : dict [str , Any ],
413412 ) -> "ResNet18Evaluator" :
414413 return cls (
415414 model_name ,
416415 fp32_model ,
417- int8_model ,
416+ quant_model ,
418417 example_input ,
419418 tosa_output_path ,
420419 batch_size = config ["batch_size" ],
@@ -424,7 +423,7 @@ def from_config(
424423 def evaluate (self ) -> dict [str , Any ]:
425424 dataset = ResNet18Evaluator .__load_dataset (self .__validation_set_path )
426425 top1 , top5 = GenericModelEvaluator .evaluate_topk (
427- self .int8_model , dataset , self .__batch_size , topk = 5
426+ self .quant_model , dataset , self .__batch_size , topk = 5
428427 )
429428 output = super ().evaluate ()
430429 output ["metrics" ]["accuracy" ] = {"top-1" : top1 , "top-5" : top5 }
@@ -463,8 +462,9 @@ def evaluator_calibration_data(
463462def evaluate_model (
464463 model_name : str ,
465464 intermediates : str ,
465+ target : str ,
466466 model_fp32 : torch .nn .Module ,
467- model_int8 : torch .nn .Module ,
467+ model_quant : torch .nn .Module ,
468468 example_inputs : Tuple [torch .Tensor ],
469469 evaluator_name : str ,
470470 evaluator_config : str | None ,
@@ -486,7 +486,7 @@ def evaluate_model(
486486 init_evaluator = factory (
487487 model_name ,
488488 model_fp32 ,
489- model_int8 ,
489+ model_quant ,
490490 example_inputs ,
491491 str (tosa_paths [0 ]),
492492 config ,
@@ -497,11 +497,11 @@ def evaluate_model(
497497 )
498498 else :
499499 init_evaluator = evaluator (
500- model_name , model_fp32 , model_int8 , example_inputs , str (tosa_paths [0 ])
500+ model_name , model_fp32 , model_quant , example_inputs , str (tosa_paths [0 ])
501501 )
502502
503503 quant_metrics = init_evaluator .evaluate ()
504- output_json_path = intermediates_path / " quant_metrics.json"
504+ output_json_path = intermediates_path / f" { target } - quant_metrics.json"
505505
506506 with output_json_path .open ("w" ) as json_file :
507507 json .dump (quant_metrics , json_file )
0 commit comments