@@ -374,10 +374,69 @@ def evaluate(self) -> dict[str, Any]:
374374 return output
375375
376376
377+ class ResNet18Evaluator (GenericModelEvaluator ):
378+ REQUIRES_CONFIG = True
379+
380+ def __init__ (
381+ self ,
382+ model_name : str ,
383+ fp32_model : Module ,
384+ int8_model : Module ,
385+ example_input : Tuple [torch .Tensor ],
386+ tosa_output_path : str | None ,
387+ batch_size : int ,
388+ validation_dataset_path : str ,
389+ ) -> None :
390+ super ().__init__ (
391+ model_name , fp32_model , int8_model , example_input , tosa_output_path
392+ )
393+ self .__batch_size = batch_size
394+ self .__validation_set_path = validation_dataset_path
395+
396+ @staticmethod
397+ def __load_dataset (directory : str ) -> datasets .ImageFolder :
398+ return _load_imagenet_folder (directory )
399+
400+ @staticmethod
401+ def get_calibrator (training_dataset_path : str ) -> DataLoader :
402+ dataset = ResNet18Evaluator .__load_dataset (training_dataset_path )
403+ return _build_calibration_loader (dataset , 1000 )
404+
405+ @classmethod
406+ def from_config (
407+ cls ,
408+ model_name : str ,
409+ fp32_model : Module ,
410+ int8_model : Module ,
411+ example_input : Tuple [torch .Tensor ],
412+ tosa_output_path : str | None ,
413+ config : dict [str , Any ],
414+ ) -> "ResNet18Evaluator" :
415+ return cls (
416+ model_name ,
417+ fp32_model ,
418+ int8_model ,
419+ example_input ,
420+ tosa_output_path ,
421+ batch_size = config ["batch_size" ],
422+ validation_dataset_path = config ["validation_dataset_path" ],
423+ )
424+
425+ def evaluate (self ) -> dict [str , Any ]:
426+ dataset = ResNet18Evaluator .__load_dataset (self .__validation_set_path )
427+ top1 , top5 = GenericModelEvaluator .evaluate_topk (
428+ self .int8_model , dataset , self .__batch_size , topk = 5
429+ )
430+ output = super ().evaluate ()
431+ output ["metrics" ]["accuracy" ] = {"top-1" : top1 , "top-5" : top5 }
432+ return output
433+
434+
377435evaluators : dict [str , type [GenericModelEvaluator ]] = {
378436 "generic" : GenericModelEvaluator ,
379437 "mv2" : MobileNetV2Evaluator ,
380438 "deit_tiny" : DeiTTinyEvaluator ,
439+ "resnet18" : ResNet18Evaluator ,
381440}
382441
383442
@@ -394,16 +453,12 @@ def evaluator_calibration_data(
394453 with config_path .open () as f :
395454 config = json .load (f )
396455
397- if evaluator is MobileNetV2Evaluator :
398- return evaluator .get_calibrator (
399- training_dataset_path = config ["training_dataset_path" ]
400- )
401- if evaluator is DeiTTinyEvaluator :
402- return evaluator .get_calibrator (
403- training_dataset_path = config ["training_dataset_path" ]
404- )
405- else :
406- raise RuntimeError (f"Unknown evaluator: { evaluator_name } " )
456+ # All current evaluators exposing calibration implement a uniform
457+ # static method signature: get_calibrator(training_dataset_path: str)
458+ # so we can call it generically without enumerating classes.
459+ return evaluator .get_calibrator (
460+ training_dataset_path = config ["training_dataset_path" ]
461+ )
407462
408463
409464def evaluate_model (
0 commit comments