Commit 7395999

Arm backend: add DeiTTiny evaluator and deterministic shuffled calibration subsets (pytorch#14579)

Change-Id: I7f61120772906ae0fec5d1f2b9cfcc0aa2c2c7af

### Summary
Add DeiTTiny evaluator for model accuracy evaluation.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

Signed-off-by: Tirui Wu <[email protected]>
1 parent c979158 commit 7395999
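
For reviewers, a condensed standalone sketch of the seeded-subset selection this commit introduces (the logic mirrors _build_calibration_loader in the diff below, minus the warning log; the dataset size and the final assert are illustrative only):

import os
import random

def pick_calibration_indices(n_samples: int, max_items: int) -> list:
    # Seed comes from ARM_EVAL_CALIB_SEED when set and parseable, else 1337,
    # matching the fallback behaviour in the patch.
    try:
        seed = int(os.getenv("ARM_EVAL_CALIB_SEED", "1337"))
    except ValueError:
        seed = 1337
    rng = random.Random(seed)
    indices = list(range(n_samples))
    rng.shuffle(indices)
    # First k of the seeded shuffle, re-sorted so enumeration order stays
    # stable while the chosen content still depends on the seed.
    return sorted(indices[: min(max_items, n_samples)])

# Same seed -> same subset, run after run:
assert pick_calibration_indices(50_000, 1000) == pick_calibration_indices(50_000, 1000)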

File tree

2 files changed (+244 lines, -64 lines)


backends/arm/util/arm_model_evaluator.py

Lines changed: 243 additions & 63 deletions
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -30,7 +29,139 @@
 logger.setLevel(logging.INFO)
 
 
+# ImageNet 224x224 transforms (Resize->CenterCrop->ToTensor->Normalize)
+# If future models require different preprocessing, extend this helper accordingly.
+def _get_imagenet_224_transforms():
+    """Return standard ImageNet 224x224 preprocessing transforms."""
+    return transforms.Compose(
+        [
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize(mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220]),
+        ]
+    )
+
+
+def _build_calibration_loader(
+    dataset: datasets.ImageFolder, max_items: int
+) -> DataLoader:
+    """Return a DataLoader over a deterministic, shuffled subset of size <= max_items.
+
+    Shuffles with seed: ARM_EVAL_CALIB_SEED (int) or default 1337; then selects first k and
+    sorts indices to keep enumeration order stable while content depends on seed.
+    """
+    k = min(max_items, len(dataset))
+    seed_env = os.getenv("ARM_EVAL_CALIB_SEED")
+    default_seed = 1337
+    if seed_env is not None:
+        try:
+            seed = int(seed_env)
+        except ValueError:
+            logger.warning(
+                "ARM_EVAL_CALIB_SEED is not an int (%s); using default seed %d",
+                seed_env,
+                default_seed,
+            )
+            seed = default_seed
+    else:
+        seed = default_seed
+    rng = random.Random(seed)
+    indices = list(range(len(dataset)))
+    rng.shuffle(indices)
+    selected = sorted(indices[:k])
+    return torch.utils.data.DataLoader(
+        torch.utils.data.Subset(dataset, selected), batch_size=1, shuffle=False
+    )
+
+
+def _load_imagenet_folder(directory: str) -> datasets.ImageFolder:
+    """Shared helper to load an ImageNet-layout folder.
+
+    Raises FileNotFoundError for a missing directory early to aid debugging.
+    """
+    directory_path = Path(directory)
+    if not directory_path.exists():
+        raise FileNotFoundError(f"Directory: {directory} does not exist.")
+    transform = _get_imagenet_224_transforms()
+    return datasets.ImageFolder(directory_path, transform=transform)
+
+
 class GenericModelEvaluator:
+    """Base evaluator computing quantization error metrics and optional compression ratio.
+
+    Subclasses can extend: provide calibration (get_calibrator) and override evaluate()
+    to add domain specific metrics (e.g. top-1 / top-5 accuracy).
+    """
+
+    @staticmethod
+    def evaluate_topk(
+        model: Module,
+        dataset: datasets.ImageFolder,
+        batch_size: int,
+        topk: int = 5,
+        log_every: int = 50,
+    ) -> Tuple[float, float]:
+        """Evaluate model top-1 / top-k accuracy.
+
+        Args:
+            model: Torch module (should be in eval() mode prior to call).
+            dataset: ImageFolder style dataset.
+            batch_size: Batch size for evaluation.
+            topk: Maximum k for accuracy (default 5).
+            log_every: Log running accuracy every N batches.
+        Returns:
+            (top1_accuracy, topk_accuracy)
+        """
+        # Some exported / quantized models (torchao PT2E) disallow direct eval()/train().
+        # Try to switch to eval mode, but degrade gracefully if unsupported.
+        try:
+            model.eval()
+        except NotImplementedError:
+            # Attempt to enable train/eval overrides if torchao helper is present.
+            try:
+                from torchao.quantization.pt2e.utils import (  # type: ignore
+                    allow_exported_model_train_eval,
+                )
+
+                allow_exported_model_train_eval(model)
+                try:
+                    model.eval()
+                except Exception:
+                    logger.debug(
+                        "Model eval still not supported after allow_exported_model_train_eval; proceeding without explicit eval()."
+                    )
+            except Exception:
+                logger.debug(
+                    "Model eval() unsupported and torchao allow_exported_model_train_eval not available; proceeding."
+                )
+        loaded_dataset = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+        top1_correct = 0
+        topk_correct = 0
+        total = 0
+        with torch.inference_mode():  # disable autograd + some backend optimizations
+            for i, (image, target) in enumerate(loaded_dataset):
+                prediction = model(image)
+                topk_indices = torch.topk(prediction, k=topk, dim=1).indices
+                # target reshaped for broadcasting
+                target_view = target.view(-1, 1)
+                top1_correct += (topk_indices[:, :1] == target_view).sum().item()
+                topk_correct += (topk_indices == target_view).sum().item()
+                batch_sz = image.size(0)
+                total += batch_sz
+                if (i + 1) % log_every == 0 or total == len(dataset):
+                    logger.info(
+                        "Eval progress: %d / %d top1=%.4f top%d=%.4f",
+                        total,
+                        len(dataset),
+                        top1_correct / total,
+                        topk,
+                        topk_correct / total,
+                    )
+        top1_accuracy = top1_correct / len(dataset)
+        topk_accuracy = topk_correct / len(dataset)
+        return top1_accuracy, topk_accuracy
+
     REQUIRES_CONFIG = False
 
     def __init__(
@@ -53,12 +184,13 @@ def __init__(
         self.tosa_output_path = ""
 
     def get_model_error(self) -> defaultdict:
-        """
-        Returns a dict containing the following metrics between the outputs of the FP32 and INT8 model:
-        - Maximum error
-        - Maximum absolute error
-        - Maximum percentage error
-        - Mean absolute error
+        """Return per-output quantization error statistics.
+
+        Metrics (lists per output tensor):
+            max_error
+            max_absolute_error
+            max_percentage_error (safe-divided; zero fp32 elements -> 0%)
+            mean_absolute_error
         """
         fp32_outputs, _ = tree_flatten(self.fp32_model(*self.example_input))
         int8_outputs, _ = tree_flatten(self.int8_model(*self.example_input))
@@ -67,7 +199,12 @@ def get_model_error(self) -> defaultdict:
 
         for fp32_output, int8_output in zip(fp32_outputs, int8_outputs):
             difference = fp32_output - int8_output
-            percentage_error = torch.div(difference, fp32_output) * 100
+            # Avoid divide by zero: elements where fp32 == 0 produce 0% contribution
+            percentage_error = torch.where(
+                fp32_output != 0,
+                difference / fp32_output * 100,
+                torch.zeros_like(difference),
+            )
             model_error_dict["max_error"].append(torch.max(difference).item())
             model_error_dict["max_absolute_error"].append(
                 torch.max(torch.abs(difference)).item()
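
To see the safe division added in the hunk above in isolation (values invented):

import torch

fp32_output = torch.tensor([2.0, 0.0, -4.0])
int8_output = torch.tensor([1.5, 0.5, -4.4])
difference = fp32_output - int8_output
percentage_error = torch.where(
    fp32_output != 0,
    difference / fp32_output * 100,
    torch.zeros_like(difference),
)
print(percentage_error)  # tensor([ 25., 0., -10.]); the zero fp32 element contributes 0%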
@@ -132,77 +269,116 @@ def __init__(
 
     @staticmethod
     def __load_dataset(directory: str) -> datasets.ImageFolder:
-        directory_path = Path(directory)
-        if not directory_path.exists():
-            raise FileNotFoundError(f"Directory: {directory} does not exist.")
-
-        transform = transforms.Compose(
-            [
-                transforms.Resize(256),
-                transforms.CenterCrop(224),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220]
-                ),
-            ]
-        )
-        return datasets.ImageFolder(directory_path, transform=transform)
+        return _load_imagenet_folder(directory)
 
     @staticmethod
     def get_calibrator(training_dataset_path: str) -> DataLoader:
         dataset = MobileNetV2Evaluator.__load_dataset(training_dataset_path)
-        rand_indices = random.sample(range(len(dataset)), k=1000)
+        return _build_calibration_loader(dataset, 1000)
 
-        # Return a subset of the dataset to be used for calibration
-        return torch.utils.data.DataLoader(
-            torch.utils.data.Subset(dataset, rand_indices),
-            batch_size=1,
-            shuffle=False,
+    @classmethod
+    def from_config(
+        cls,
+        model_name: str,
+        fp32_model: Module,
+        int8_model: Module,
+        example_input: Tuple[torch.Tensor],
+        tosa_output_path: str | None,
+        config: dict[str, Any],
+    ) -> "MobileNetV2Evaluator":
+        """Factory constructing evaluator from a config dict.
+
+        Expected keys: batch_size, validation_dataset_path
+        """
+        return cls(
+            model_name,
+            fp32_model,
+            int8_model,
+            example_input,
+            tosa_output_path,
+            batch_size=config["batch_size"],
+            validation_dataset_path=config["validation_dataset_path"],
         )
 
-    def __evaluate_mobilenet(self) -> Tuple[float, float]:
+    def evaluate(self) -> dict[str, Any]:
+        # Load dataset and compute top-1 / top-5
         dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path)
-        loaded_dataset = DataLoader(
-            dataset,
-            batch_size=self.__batch_size,
-            shuffle=False,
+        top1_correct, top5_correct = GenericModelEvaluator.evaluate_topk(
            self.int8_model, dataset, self.__batch_size, topk=5
         )
+        output = super().evaluate()
 
-        top1_correct = 0
-        top5_correct = 0
+        output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct}
+        return output
 
-        for i, (image, target) in enumerate(loaded_dataset):
-            prediction = self.int8_model(image)
-            top1_prediction = torch.topk(prediction, k=1, dim=1).indices
-            top5_prediction = torch.topk(prediction, k=5, dim=1).indices
 
-            top1_correct += (top1_prediction == target.view(-1, 1)).sum().item()
-            top5_correct += (top5_prediction == target.view(-1, 1)).sum().item()
+class DeiTTinyEvaluator(GenericModelEvaluator):
+    REQUIRES_CONFIG = True
 
-            logger.info("Iteration: {}".format((i + 1) * self.__batch_size))
-            logger.info(
-                "Top 1: {}".format(top1_correct / ((i + 1) * self.__batch_size))
-            )
-            logger.info(
-                "Top 5: {}".format(top5_correct / ((i + 1) * self.__batch_size))
-            )
+    def __init__(
+        self,
+        model_name: str,
+        fp32_model: Module,
+        int8_model: Module,
+        example_input: Tuple[torch.Tensor],
+        tosa_output_path: str | None,
+        batch_size: int,
+        validation_dataset_path: str,
+    ) -> None:
+        super().__init__(
+            model_name, fp32_model, int8_model, example_input, tosa_output_path
+        )
+        self.__batch_size = batch_size
+        self.__validation_set_path = validation_dataset_path
 
-        top1_accuracy = top1_correct / len(dataset)
-        top5_accuracy = top5_correct / len(dataset)
+    @staticmethod
+    def __load_dataset(directory: str) -> datasets.ImageFolder:
+        return _load_imagenet_folder(directory)
 
-        return top1_accuracy, top5_accuracy
+    @staticmethod
+    def get_calibrator(training_dataset_path: str) -> DataLoader:
+        dataset = DeiTTinyEvaluator.__load_dataset(training_dataset_path)
+        return _build_calibration_loader(dataset, 1000)
+
+    @classmethod
+    def from_config(
+        cls,
+        model_name: str,
+        fp32_model: Module,
+        int8_model: Module,
+        example_input: Tuple[torch.Tensor],
+        tosa_output_path: str | None,
+        config: dict[str, Any],
+    ) -> "DeiTTinyEvaluator":
+        """Factory constructing evaluator from a config dict.
+
+        Expected keys: batch_size, validation_dataset_path
+        """
+        return cls(
+            model_name,
+            fp32_model,
+            int8_model,
+            example_input,
+            tosa_output_path,
+            batch_size=config["batch_size"],
+            validation_dataset_path=config["validation_dataset_path"],
+        )
 
     def evaluate(self) -> dict[str, Any]:
-        top1_correct, top5_correct = self.__evaluate_mobilenet()
+        # Load dataset and compute top-1 / top-5
+        dataset = DeiTTinyEvaluator.__load_dataset(self.__validation_set_path)
+        top1, top5 = GenericModelEvaluator.evaluate_topk(
+            self.int8_model, dataset, self.__batch_size, topk=5
+        )
         output = super().evaluate()
-
-        output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct}
+        output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5}
         return output
 
 
 evaluators: dict[str, type[GenericModelEvaluator]] = {
     "generic": GenericModelEvaluator,
     "mv2": MobileNetV2Evaluator,
+    "deit_tiny": DeiTTinyEvaluator,
 }
 
 
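The hunk above registers the new evaluator under "deit_tiny". A hypothetical construction path through that registry (the import path assumes the usual executorch package layout; the batch size, dataset path, and input shape are made up):

import torch
from executorch.backends.arm.util.arm_model_evaluator import evaluators

def make_deit_tiny_evaluator(fp32_model, int8_model):
    evaluator_cls = evaluators["deit_tiny"]  # registry entry added in this commit
    config = {"batch_size": 32, "validation_dataset_path": "/data/imagenet/val"}
    example_input = (torch.randn(1, 3, 224, 224),)
    # from_config reads batch_size and validation_dataset_path from config.
    return evaluator_cls.from_config(
        "deit_tiny", fp32_model, int8_model, example_input, None, config
    )
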
@@ -223,6 +399,10 @@ def evaluator_calibration_data(
         return evaluator.get_calibrator(
             training_dataset_path=config["training_dataset_path"]
         )
+    if evaluator is DeiTTinyEvaluator:
+        return evaluator.get_calibrator(
+            training_dataset_path=config["training_dataset_path"]
+        )
     else:
         raise RuntimeError(f"Unknown evaluator: {evaluator_name}")
 
@@ -238,30 +418,30 @@ def evaluate_model(
 ) -> None:
     evaluator = evaluators[evaluator_name]
 
-    # Get the path of the TOSA flatbuffer that is dumped
     intermediates_path = Path(intermediates)
     tosa_paths = list(intermediates_path.glob("*.tosa"))
 
     if evaluator.REQUIRES_CONFIG:
         assert evaluator_config is not None
-
         config_path = Path(evaluator_config)
         with config_path.open() as f:
             config = json.load(f)
 
-        if evaluator == MobileNetV2Evaluator:
-            mv2_evaluator = cast(type[MobileNetV2Evaluator], evaluator)
-            init_evaluator: GenericModelEvaluator = mv2_evaluator(
+        # Prefer a subclass provided from_config if available.
+        if hasattr(evaluator, "from_config"):
+            factory = cast(Any, evaluator.from_config)  # type: ignore[attr-defined]
+            init_evaluator = factory(
                 model_name,
                 model_fp32,
                 model_int8,
                 example_inputs,
                 str(tosa_paths[0]),
-                batch_size=config["batch_size"],
-                validation_dataset_path=config["validation_dataset_path"],
+                config,
             )
         else:
-            raise RuntimeError(f"Unknown evaluator {evaluator_name}")
+            raise RuntimeError(
+                f"Evaluator {evaluator_name} requires config but does not implement from_config()"
+            )
     else:
         init_evaluator = evaluator(
             model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
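
And a self-contained check of the broadcast top-k counting used by evaluate_topk above (logits invented; ties avoided so the counts are deterministic):

import torch

logits = torch.tensor([[0.1, 0.9, 0.0], [0.8, 0.15, 0.05]])  # 2 samples, 3 classes
target = torch.tensor([1, 2])
topk_indices = torch.topk(logits, k=2, dim=1).indices  # rows: [1, 0] and [0, 1]
target_view = target.view(-1, 1)  # shape (2, 1) broadcasts against (2, 2)
top1_correct = (topk_indices[:, :1] == target_view).sum().item()  # sample 0 only -> 1
topk_correct = (topk_indices == target_view).sum().item()         # sample 0 only -> 1
print(top1_correct, topk_correct)  # 1 1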
