Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 243 additions & 63 deletions backends/arm/util/arm_model_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unintentional delete?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is intentional since file is anyway modified and there is only an Arm copyright.

#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -30,7 +29,139 @@
logger.setLevel(logging.INFO)


# ImageNet 224x224 transforms (Resize->CenterCrop->ToTensor->Normalize)
# If future models require different preprocessing, extend this helper accordingly.
def _get_imagenet_224_transforms():
"""Return standard ImageNet 224x224 preprocessing transforms."""
return transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220]),
]
)


def _build_calibration_loader(
dataset: datasets.ImageFolder, max_items: int
) -> DataLoader:
"""Return a DataLoader over a deterministic, shuffled subset of size <= max_items.

Shuffles with seed: ARM_EVAL_CALIB_SEED (int) or default 1337; then selects first k and
sorts indices to keep enumeration order stable while content depends on seed.
"""
k = min(max_items, len(dataset))
seed_env = os.getenv("ARM_EVAL_CALIB_SEED")
default_seed = 1337
if seed_env is not None:
try:
seed = int(seed_env)
except ValueError:
logger.warning(
"ARM_EVAL_CALIB_SEED is not an int (%s); using default seed %d",
seed_env,
default_seed,
)
seed = default_seed
else:
seed = default_seed
rng = random.Random(seed)
indices = list(range(len(dataset)))
rng.shuffle(indices)
selected = sorted(indices[:k])
return torch.utils.data.DataLoader(
torch.utils.data.Subset(dataset, selected), batch_size=1, shuffle=False
)


def _load_imagenet_folder(directory: str) -> datasets.ImageFolder:
"""Shared helper to load an ImageNet-layout folder.

Raises FileNotFoundError for a missing directory early to aid debugging.
"""
directory_path = Path(directory)
if not directory_path.exists():
raise FileNotFoundError(f"Directory: {directory} does not exist.")
transform = _get_imagenet_224_transforms()
return datasets.ImageFolder(directory_path, transform=transform)


class GenericModelEvaluator:
"""Base evaluator computing quantization error metrics and optional compression ratio.

Subclasses can extend: provide calibration (get_calibrator) and override evaluate()
to add domain specific metrics (e.g. top-1 / top-5 accuracy).
"""

@staticmethod
def evaluate_topk(
model: Module,
dataset: datasets.ImageFolder,
batch_size: int,
topk: int = 5,
log_every: int = 50,
) -> Tuple[float, float]:
"""Evaluate model top-1 / top-k accuracy.

Args:
model: Torch module (should be in eval() mode prior to call).
dataset: ImageFolder style dataset.
batch_size: Batch size for evaluation.
topk: Maximum k for accuracy (default 5).
log_every: Log running accuracy every N batches.
Returns:
(top1_accuracy, topk_accuracy)
"""
# Some exported / quantized models (torchao PT2E) disallow direct eval()/train().
# Try to switch to eval mode, but degrade gracefully if unsupported.
try:
model.eval()
except NotImplementedError:
# Attempt to enable train/eval overrides if torchao helper is present.
try:
from torchao.quantization.pt2e.utils import ( # type: ignore
allow_exported_model_train_eval,
)

allow_exported_model_train_eval(model)
try:
model.eval()
except Exception:
logger.debug(
"Model eval still not supported after allow_exported_model_train_eval; proceeding without explicit eval()."
)
except Exception:
logger.debug(
"Model eval() unsupported and torchao allow_exported_model_train_eval not available; proceeding."
)
loaded_dataset = DataLoader(dataset, batch_size=batch_size, shuffle=False)
top1_correct = 0
topk_correct = 0
total = 0
with torch.inference_mode(): # disable autograd + some backend optimizations
for i, (image, target) in enumerate(loaded_dataset):
prediction = model(image)
topk_indices = torch.topk(prediction, k=topk, dim=1).indices
# target reshaped for broadcasting
target_view = target.view(-1, 1)
top1_correct += (topk_indices[:, :1] == target_view).sum().item()
topk_correct += (topk_indices == target_view).sum().item()
batch_sz = image.size(0)
total += batch_sz
if (i + 1) % log_every == 0 or total == len(dataset):
logger.info(
"Eval progress: %d / %d top1=%.4f top%d=%.4f",
total,
len(dataset),
top1_correct / total,
topk,
topk_correct / total,
)
top1_accuracy = top1_correct / len(dataset)
topk_accuracy = topk_correct / len(dataset)
return top1_accuracy, topk_accuracy

REQUIRES_CONFIG = False

def __init__(
Expand All @@ -53,12 +184,13 @@ def __init__(
self.tosa_output_path = ""

def get_model_error(self) -> defaultdict:
"""
Returns a dict containing the following metrics between the outputs of the FP32 and INT8 model:
- Maximum error
- Maximum absolute error
- Maximum percentage error
- Mean absolute error
"""Return per-output quantization error statistics.

Metrics (lists per output tensor):
max_error
max_absolute_error
max_percentage_error (safe-divided; zero fp32 elements -> 0%)
mean_absolute_error
"""
fp32_outputs, _ = tree_flatten(self.fp32_model(*self.example_input))
int8_outputs, _ = tree_flatten(self.int8_model(*self.example_input))
Expand All @@ -67,7 +199,12 @@ def get_model_error(self) -> defaultdict:

for fp32_output, int8_output in zip(fp32_outputs, int8_outputs):
difference = fp32_output - int8_output
percentage_error = torch.div(difference, fp32_output) * 100
# Avoid divide by zero: elements where fp32 == 0 produce 0% contribution
percentage_error = torch.where(
fp32_output != 0,
difference / fp32_output * 100,
torch.zeros_like(difference),
)
model_error_dict["max_error"].append(torch.max(difference).item())
model_error_dict["max_absolute_error"].append(
torch.max(torch.abs(difference)).item()
Expand Down Expand Up @@ -132,77 +269,116 @@ def __init__(

@staticmethod
def __load_dataset(directory: str) -> datasets.ImageFolder:
directory_path = Path(directory)
if not directory_path.exists():
raise FileNotFoundError(f"Directory: {directory} does not exist.")

transform = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220]
),
]
)
return datasets.ImageFolder(directory_path, transform=transform)
return _load_imagenet_folder(directory)

@staticmethod
def get_calibrator(training_dataset_path: str) -> DataLoader:
dataset = MobileNetV2Evaluator.__load_dataset(training_dataset_path)
rand_indices = random.sample(range(len(dataset)), k=1000)
return _build_calibration_loader(dataset, 1000)

# Return a subset of the dataset to be used for calibration
return torch.utils.data.DataLoader(
torch.utils.data.Subset(dataset, rand_indices),
batch_size=1,
shuffle=False,
@classmethod
def from_config(
cls,
model_name: str,
fp32_model: Module,
int8_model: Module,
example_input: Tuple[torch.Tensor],
tosa_output_path: str | None,
config: dict[str, Any],
) -> "MobileNetV2Evaluator":
"""Factory constructing evaluator from a config dict.

Expected keys: batch_size, validation_dataset_path
"""
return cls(
model_name,
fp32_model,
int8_model,
example_input,
tosa_output_path,
batch_size=config["batch_size"],
validation_dataset_path=config["validation_dataset_path"],
)

def __evaluate_mobilenet(self) -> Tuple[float, float]:
def evaluate(self) -> dict[str, Any]:
# Load dataset and compute top-1 / top-5
dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path)
loaded_dataset = DataLoader(
dataset,
batch_size=self.__batch_size,
shuffle=False,
top1_correct, top5_correct = GenericModelEvaluator.evaluate_topk(
self.int8_model, dataset, self.__batch_size, topk=5
)
output = super().evaluate()

top1_correct = 0
top5_correct = 0
output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct}
return output

for i, (image, target) in enumerate(loaded_dataset):
prediction = self.int8_model(image)
top1_prediction = torch.topk(prediction, k=1, dim=1).indices
top5_prediction = torch.topk(prediction, k=5, dim=1).indices

top1_correct += (top1_prediction == target.view(-1, 1)).sum().item()
top5_correct += (top5_prediction == target.view(-1, 1)).sum().item()
class DeiTTinyEvaluator(GenericModelEvaluator):
REQUIRES_CONFIG = True

logger.info("Iteration: {}".format((i + 1) * self.__batch_size))
logger.info(
"Top 1: {}".format(top1_correct / ((i + 1) * self.__batch_size))
)
logger.info(
"Top 5: {}".format(top5_correct / ((i + 1) * self.__batch_size))
)
def __init__(
self,
model_name: str,
fp32_model: Module,
int8_model: Module,
example_input: Tuple[torch.Tensor],
tosa_output_path: str | None,
batch_size: int,
validation_dataset_path: str,
) -> None:
super().__init__(
model_name, fp32_model, int8_model, example_input, tosa_output_path
)
self.__batch_size = batch_size
self.__validation_set_path = validation_dataset_path

top1_accuracy = top1_correct / len(dataset)
top5_accuracy = top5_correct / len(dataset)
@staticmethod
def __load_dataset(directory: str) -> datasets.ImageFolder:
return _load_imagenet_folder(directory)

return top1_accuracy, top5_accuracy
@staticmethod
def get_calibrator(training_dataset_path: str) -> DataLoader:
dataset = DeiTTinyEvaluator.__load_dataset(training_dataset_path)
return _build_calibration_loader(dataset, 1000)

@classmethod
def from_config(
cls,
model_name: str,
fp32_model: Module,
int8_model: Module,
example_input: Tuple[torch.Tensor],
tosa_output_path: str | None,
config: dict[str, Any],
) -> "DeiTTinyEvaluator":
"""Factory constructing evaluator from a config dict.

Expected keys: batch_size, validation_dataset_path
"""
return cls(
model_name,
fp32_model,
int8_model,
example_input,
tosa_output_path,
batch_size=config["batch_size"],
validation_dataset_path=config["validation_dataset_path"],
)

def evaluate(self) -> dict[str, Any]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it better to refactor this and MobileNetV2Evaluator to share much of the code?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good suggestion, will do.

top1_correct, top5_correct = self.__evaluate_mobilenet()
# Load dataset and compute top-1 / top-5
dataset = DeiTTinyEvaluator.__load_dataset(self.__validation_set_path)
top1, top5 = GenericModelEvaluator.evaluate_topk(
self.int8_model, dataset, self.__batch_size, topk=5
)
output = super().evaluate()

output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct}
output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5}
return output


evaluators: dict[str, type[GenericModelEvaluator]] = {
"generic": GenericModelEvaluator,
"mv2": MobileNetV2Evaluator,
"deit_tiny": DeiTTinyEvaluator,
}


Expand All @@ -223,6 +399,10 @@ def evaluator_calibration_data(
return evaluator.get_calibrator(
training_dataset_path=config["training_dataset_path"]
)
if evaluator is DeiTTinyEvaluator:
return evaluator.get_calibrator(
training_dataset_path=config["training_dataset_path"]
)
else:
raise RuntimeError(f"Unknown evaluator: {evaluator_name}")

Expand All @@ -238,30 +418,30 @@ def evaluate_model(
) -> None:
evaluator = evaluators[evaluator_name]

# Get the path of the TOSA flatbuffer that is dumped
intermediates_path = Path(intermediates)
tosa_paths = list(intermediates_path.glob("*.tosa"))

if evaluator.REQUIRES_CONFIG:
assert evaluator_config is not None

config_path = Path(evaluator_config)
with config_path.open() as f:
config = json.load(f)

if evaluator == MobileNetV2Evaluator:
mv2_evaluator = cast(type[MobileNetV2Evaluator], evaluator)
init_evaluator: GenericModelEvaluator = mv2_evaluator(
# Prefer a subclass provided from_config if available.
if hasattr(evaluator, "from_config"):
factory = cast(Any, evaluator.from_config) # type: ignore[attr-defined]
init_evaluator = factory(
model_name,
model_fp32,
model_int8,
example_inputs,
str(tosa_paths[0]),
batch_size=config["batch_size"],
validation_dataset_path=config["validation_dataset_path"],
config,
)
else:
raise RuntimeError(f"Unknown evaluator {evaluator_name}")
raise RuntimeError(
f"Evaluator {evaluator_name} requires config but does not implement from_config()"
)
else:
init_evaluator = evaluator(
model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
Expand Down
Loading
Loading