
Commit 8374421

Arm backend: add ResNet18 evaluator and introduce DeiT Tiny builtin (#15517)
Signed-off-by: Tirui Wu <[email protected]>
1 parent 80a43ac commit 8374421

File tree

5 files changed: +118 additions, -11 deletions

backends/arm/util/arm_model_evaluator.py

Lines changed: 65 additions & 10 deletions
@@ -374,10 +374,69 @@ def evaluate(self) -> dict[str, Any]:
         return output


+class ResNet18Evaluator(GenericModelEvaluator):
+    REQUIRES_CONFIG = True
+
+    def __init__(
+        self,
+        model_name: str,
+        fp32_model: Module,
+        int8_model: Module,
+        example_input: Tuple[torch.Tensor],
+        tosa_output_path: str | None,
+        batch_size: int,
+        validation_dataset_path: str,
+    ) -> None:
+        super().__init__(
+            model_name, fp32_model, int8_model, example_input, tosa_output_path
+        )
+        self.__batch_size = batch_size
+        self.__validation_set_path = validation_dataset_path
+
+    @staticmethod
+    def __load_dataset(directory: str) -> datasets.ImageFolder:
+        return _load_imagenet_folder(directory)
+
+    @staticmethod
+    def get_calibrator(training_dataset_path: str) -> DataLoader:
+        dataset = ResNet18Evaluator.__load_dataset(training_dataset_path)
+        return _build_calibration_loader(dataset, 1000)
+
+    @classmethod
+    def from_config(
+        cls,
+        model_name: str,
+        fp32_model: Module,
+        int8_model: Module,
+        example_input: Tuple[torch.Tensor],
+        tosa_output_path: str | None,
+        config: dict[str, Any],
+    ) -> "ResNet18Evaluator":
+        return cls(
+            model_name,
+            fp32_model,
+            int8_model,
+            example_input,
+            tosa_output_path,
+            batch_size=config["batch_size"],
+            validation_dataset_path=config["validation_dataset_path"],
+        )
+
+    def evaluate(self) -> dict[str, Any]:
+        dataset = ResNet18Evaluator.__load_dataset(self.__validation_set_path)
+        top1, top5 = GenericModelEvaluator.evaluate_topk(
+            self.int8_model, dataset, self.__batch_size, topk=5
+        )
+        output = super().evaluate()
+        output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5}
+        return output
+
+
 evaluators: dict[str, type[GenericModelEvaluator]] = {
     "generic": GenericModelEvaluator,
     "mv2": MobileNetV2Evaluator,
     "deit_tiny": DeiTTinyEvaluator,
+    "resnet18": ResNet18Evaluator,
 }


@@ -394,16 +453,12 @@ def evaluator_calibration_data(
         with config_path.open() as f:
             config = json.load(f)

-        if evaluator is MobileNetV2Evaluator:
-            return evaluator.get_calibrator(
-                training_dataset_path=config["training_dataset_path"]
-            )
-        if evaluator is DeiTTinyEvaluator:
-            return evaluator.get_calibrator(
-                training_dataset_path=config["training_dataset_path"]
-            )
-        else:
-            raise RuntimeError(f"Unknown evaluator: {evaluator_name}")
+        # All current evaluators exposing calibration implement a uniform
+        # static method signature: get_calibrator(training_dataset_path: str)
+        # so we can call it generically without enumerating classes.
+        return evaluator.get_calibrator(
+            training_dataset_path=config["training_dataset_path"]
+        )


 def evaluate_model(
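For context, a minimal sketch of the evaluator config the new class consumes, based only on the keys read in the diff above (batch_size and validation_dataset_path in from_config, training_dataset_path in get_calibrator). The values and file name are placeholders, not taken from the commit.

    # Hypothetical config for the "resnet18" evaluator; keys mirror what
    # ResNet18Evaluator.from_config() and get_calibrator() read above.
    import json
    from pathlib import Path

    config = {
        "batch_size": 32,                                  # assumed value
        "validation_dataset_path": "/data/imagenet/val",   # placeholder path
        "training_dataset_path": "/data/imagenet/train",   # placeholder path
    }

    # Write it out as JSON so it can be passed as the evaluator config file
    # (file name is a placeholder, not from the commit).
    Path("resnet18_eval_config.json").write_text(json.dumps(config, indent=2))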

examples/arm/aot_arm_compiler.py

Lines changed: 1 addition & 1 deletion
@@ -492,7 +492,7 @@ def get_args():
         required=False,
         nargs="?",
         const="generic",
-        choices=["generic", "mv2", "deit_tiny"],
+        choices=["generic", "mv2", "deit_tiny", "resnet18"],
         help="Flag for running evaluation of the model.",
     )
     parser.add_argument(

examples/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,7 @@ class Model(str, Enum):
     Qwen25 = "qwen2_5_1_5b"
     Phi4Mini = "phi_4_mini"
     SmolLM2 = "smollm2"
+    DeiTTiny = "deit_tiny"

     def __str__(self) -> str:
         return self.value
@@ -87,6 +88,7 @@ def __str__(self) -> str:
     str(Model.Qwen25): ("qwen2_5", "Qwen2_5Model"),
     str(Model.Phi4Mini): ("phi_4_mini", "Phi4MiniModel"),
     str(Model.SmolLM2): ("smollm2", "SmolLM2Model"),
+    str(Model.DeiTTiny): ("deit_tiny", "DeiTTinyModel"),
 }

 __all__ = [
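As an illustration (not part of this commit), an entry such as ("deit_tiny", "DeiTTinyModel") is typically resolved by importing the named module and class dynamically; the registry argument and the package path below are assumptions.

    # Sketch of resolving a (module, class) registry entry via importlib.
    # The registry dict and the examples.models package path are assumptions.
    import importlib

    def load_builtin_model(registry: dict, name: str):
        module_name, class_name = registry[name]
        module = importlib.import_module(f"executorch.examples.models.{module_name}")
        model_class = getattr(module, class_name)
        return model_class()  # e.g. DeiTTinyModel() for name == "deit_tiny"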
examples/models/deit_tiny/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model import DeiTTinyModel
+
+__all__ = ["DeiTTinyModel"]

examples/models/deit_tiny/model.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch
+from torchvision import transforms
+
+try:
+    import timm  # type: ignore
+except ImportError as e:  # pragma: no cover
+    raise RuntimeError(
+        "timm package is required for builtin 'deit_tiny'. Install timm."
+    ) from e
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+
+from ..model_base import EagerModelBase
+
+
+class DeiTTinyModel(EagerModelBase):
+
+    def __init__(self):  # type: ignore[override]
+        pass
+
+    def get_eager_model(self) -> torch.nn.Module:  # type: ignore[override]
+        logging.info("Loading timm deit_tiny_patch16_224 model")
+        model = timm.models.deit.deit_tiny_patch16_224(pretrained=False)
+        model.eval()
+        logging.info("Loaded timm deit_tiny_patch16_224 model")
+        return model
+
+    def get_example_inputs(self):  # type: ignore[override]
+        normalize = transforms.Normalize(
+            mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD
+        )
+        return (normalize(torch.rand((1, 3, 224, 224))),)
+
+
+__all__ = ["DeiTTinyModel"]
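A short usage sketch for the new builtin, assuming only the EagerModelBase interface shown above; the import path is an assumption and the forward pass is illustrative, not taken from the commit.

    # Illustrative use of the deit_tiny builtin: get the eager model plus
    # example inputs and run a forward pass (requires timm and torchvision).
    import torch

    from executorch.examples.models.deit_tiny import DeiTTinyModel  # path assumed

    wrapper = DeiTTinyModel()
    model = wrapper.get_eager_model()               # deit_tiny_patch16_224 in eval mode
    example_inputs = wrapper.get_example_inputs()   # (1, 3, 224, 224) normalized tensor

    with torch.no_grad():
        logits = model(*example_inputs)             # ImageNet-1k logits, shape (1, 1000)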
