diff --git a/backends/arm/test/misc/test_model_evaluator.py b/backends/arm/test/misc/test_model_evaluator.py
index 1f23f176bc7..d88bdfb0000 100644
--- a/backends/arm/test/misc/test_model_evaluator.py
+++ b/backends/arm/test/misc/test_model_evaluator.py
@@ -4,17 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import random
 import tempfile
 import unittest
 
 import torch
 from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator
 
-random.seed(0)
-
 # Create an input that is hard to compress
-COMPRESSION_RATIO_TEST = bytearray(random.getrandbits(8) for _ in range(1000000))
+COMPRESSION_RATIO_TEST = torch.rand([1024, 1024])
 
 
 def mocked_model_1(input: torch.Tensor) -> torch.Tensor:
@@ -47,11 +44,7 @@ def test_get_model_error(self):
 
     def test_get_compression_ratio(self):
         with tempfile.NamedTemporaryFile(delete=True) as temp_bin:
-            temp_bin.write(COMPRESSION_RATIO_TEST)
-
-            # As the size of the file is quite small we need to call flush()
-            temp_bin.flush()
-            temp_bin_name = temp_bin.name
+            torch.save(COMPRESSION_RATIO_TEST, temp_bin)
 
             example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
             evaluator = GenericModelEvaluator(
@@ -59,8 +52,8 @@ def test_get_compression_ratio(self):
                 mocked_model_1,
                 mocked_model_2,
                 example_input,
-                temp_bin_name,
+                temp_bin.name,
             )
 
             ratio = evaluator.get_compression_ratio()
-            self.assertAlmostEqual(ratio, 1.0, places=2)
+            self.assertAlmostEqual(ratio, 1.1, places=1)
diff --git a/backends/arm/util/arm_model_evaluator.py b/backends/arm/util/arm_model_evaluator.py
index 99e142abd36..f8aeab25ba1 100644
--- a/backends/arm/util/arm_model_evaluator.py
+++ b/backends/arm/util/arm_model_evaluator.py
@@ -4,13 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import os
+import random
 import tempfile
 import zipfile
+
 from collections import defaultdict
-from typing import Optional, Tuple
+from pathlib import Path
+from typing import Any, Optional, Tuple
 
 import torch
+from torch.nn.modules import Module
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+
+# Logger for outputting progress for longer running evaluation
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 
 def flatten_args(args) -> tuple | list:
@@ -28,6 +40,8 @@ def flatten_args(args) -> tuple | list:
 
 
 class GenericModelEvaluator:
+    REQUIRES_CONFIG = False
+
     def __init__(
         self,
         model_name: str,
@@ -90,7 +104,7 @@ def get_compression_ratio(self) -> float:
 
         return compression_ratio
 
-    def evaluate(self) -> dict[any]:
+    def evaluate(self) -> dict[Any]:
         model_error_dict = self.get_model_error()
         output_metrics = {"name": self.model_name, "metrics": dict(model_error_dict)}
 
@@ -103,3 +117,93 @@ def evaluate(self) -> dict[any]:
             ] = self.get_compression_ratio()
 
         return output_metrics
+
+
+class MobileNetV2Evaluator(GenericModelEvaluator):
+    REQUIRES_CONFIG = True
+
+    def __init__(
+        self,
+        model_name: str,
+        fp32_model: Module,
+        int8_model: Module,
+        example_input: Tuple[torch.Tensor],
+        tosa_output_path: str | None,
+        batch_size: int,
+        validation_dataset_path: str,
+    ) -> None:
+        super().__init__(
+            model_name, fp32_model, int8_model, example_input, tosa_output_path
+        )
+
+        self.__batch_size = batch_size
+        self.__validation_set_path = validation_dataset_path
+
+    @staticmethod
+    def __load_dataset(directory: str) -> datasets.ImageFolder:
+        directory_path = Path(directory)
+        if not directory_path.exists():
+            raise FileNotFoundError(f"Directory: {directory} does not exist.")
+
+        transform = transforms.Compose(
+            [
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220]
+                ),
+            ]
+        )
+        return datasets.ImageFolder(directory_path, transform=transform)
+
+    @staticmethod
+    def get_calibrator(training_dataset_path: str) -> DataLoader:
+        dataset = MobileNetV2Evaluator.__load_dataset(training_dataset_path)
+        rand_indices = random.sample(range(len(dataset)), k=1000)
+
+        # Return a subset of the dataset to be used for calibration
+        return torch.utils.data.DataLoader(
+            torch.utils.data.Subset(dataset, rand_indices),
+            batch_size=1,
+            shuffle=False,
+        )
+
+    def __evaluate_mobilenet(self) -> Tuple[float, float]:
+        dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path)
+        loaded_dataset = DataLoader(
+            dataset,
+            batch_size=self.__batch_size,
+            shuffle=False,
+        )
+
+        top1_correct = 0
+        top5_correct = 0
+
+        for i, (image, target) in enumerate(loaded_dataset):
+            prediction = self.int8_model(image)
+            top1_prediction = torch.topk(prediction, k=1, dim=1).indices
+            top5_prediction = torch.topk(prediction, k=5, dim=1).indices
+
+            top1_correct += (top1_prediction == target.view(-1, 1)).sum().item()
+            top5_correct += (top5_prediction == target.view(-1, 1)).sum().item()
+
+            logger.info("Iteration: {}".format((i + 1) * self.__batch_size))
+            logger.info(
+                "Top 1: {}".format(top1_correct / ((i + 1) * self.__batch_size))
+            )
+            logger.info(
+                "Top 5: {}".format(top5_correct / ((i + 1) * self.__batch_size))
+            )
+
+        top1_accuracy = top1_correct / len(dataset)
+        top5_accuracy = top5_correct / len(dataset)
+
+        return top1_accuracy, top5_accuracy
+
+    def evaluate(self) -> dict[str, Any]:
+        top1_correct, top5_correct = self.__evaluate_mobilenet()
+        output = super().evaluate()
+
+        output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct}
+        return output
diff --git a/examples/arm/aot_arm_compiler.py b/examples/arm/aot_arm_compiler.py
index f5028713c72..4953f8735e3 100644
--- a/examples/arm/aot_arm_compiler.py
+++ b/examples/arm/aot_arm_compiler.py
@@ -13,7 +13,7 @@
 import os
 
 from pathlib import Path
-from typing import Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder
@@ -22,8 +22,11 @@
     ArmQuantizer,
     get_symmetric_quantization_config,
 )
-from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator
+from executorch.backends.arm.util.arm_model_evaluator import (
+    GenericModelEvaluator,
+    MobileNetV2Evaluator,
+)
 
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.exir import (
     EdgeCompileConfig,
@@ -35,6 +38,7 @@
 
 # Quantize model if required using the standard export quantizaion flow.
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.utils.data import DataLoader
 
 from ..models import MODEL_NAME_TO_MODEL
 from ..models.model_factory import EagerModelFactory
@@ -43,7 +47,7 @@
 logging.basicConfig(level=logging.WARNING, format=FORMAT)
 
 
-def get_model_and_inputs_from_name(model_name: str):
+def get_model_and_inputs_from_name(model_name: str) -> Tuple[torch.nn.Module, Any]:
     """Given the name of an example pytorch model, return it and example inputs.
 
     Raises RuntimeError if there is no example model corresponding to the given name.
@@ -81,20 +85,37 @@ def get_model_and_inputs_from_name(model_name: str):
     return model, example_inputs
 
 
-def quantize(model, example_inputs):
+def quantize(
+    model: torch.nn.Module,
+    model_name: str,
+    example_inputs: Tuple[torch.Tensor],
+    evaluator_name: str | None,
+    evaluator_config: Dict[str, Any] | None,
+) -> torch.nn.Module:
     """This is the official recommended flow for quantization in pytorch 2.0 export"""
     logging.info("Quantizing Model...")
     logging.debug(f"Original model: {model}")
     quantizer = ArmQuantizer()
+
     # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
     operator_config = get_symmetric_quantization_config(is_per_channel=False)
     quantizer.set_global(operator_config)
     m = prepare_pt2e(model, quantizer)
-    # calibration
-    m(*example_inputs)
+
+    dataset = get_calibration_data(
+        model_name, example_inputs, evaluator_name, evaluator_config
+    )
+
+    # The dataset could be a tuple of tensors or a DataLoader
+    # These two cases need to be accounted for
+    if isinstance(dataset, DataLoader):
+        for sample, _ in dataset:
+            m(sample)
+    else:
+        m(*dataset)
+
     m = convert_pt2e(m)
     logging.debug(f"Quantized model: {m}")
-    # make sure we can export to flat buffer
     return m
 
 
@@ -158,7 +179,23 @@ def forward(self, x):
     "softmax": SoftmaxModule,
 }
 
-evaluators = {}
+calibration_data = {
+    "add": (torch.randn(1, 5),),
+    "add2": (
+        torch.randn(1, 5),
+        torch.randn(1, 5),
+    ),
+    "add3": (
+        torch.randn(32, 5),
+        torch.randn(32, 5),
+    ),
+    "softmax": (torch.randn(32, 2, 2),),
+}
+
+evaluators = {
+    "generic": GenericModelEvaluator,
+    "mv2": MobileNetV2Evaluator,
+}
 
 targets = [
     "ethos-u55-32",
@@ -174,6 +211,39 @@ def forward(self, x):
 ]
 
 
+def get_calibration_data(
+    model_name: str,
+    example_inputs: Tuple[torch.Tensor],
+    evaluator_name: str | None,
+    evaluator_config: str | None,
+):
+    # Firstly, if the model is being evaluated, take the evaluator's calibration function if it has one
+    if evaluator_name is not None:
+        evaluator = evaluators[evaluator_name]
+
+        if hasattr(evaluator, "get_calibrator"):
+            assert evaluator_config is not None
+
+            config_path = Path(evaluator_config)
+            with config_path.open() as f:
+                config = json.load(f)
+
+            if evaluator_name == "mv2":
+                return evaluator.get_calibrator(
+                    training_dataset_path=config["training_dataset_path"]
+                )
+            else:
+                raise RuntimeError(f"Unknown evaluator: {evaluator_name}")
+
+    # If the model is in the calibration_data dictionary, get the data from there
+    # This is used for the simple model examples provided
+    if model_name in calibration_data:
+        return calibration_data[model_name]
+
+    # As a last resort, fall back to the script's previous behavior and return the example inputs
+    return example_inputs
+
+
 def get_compile_spec(
     target: str, intermediates: Optional[str] = None
 ) -> ArmCompileSpecBuilder:
@@ -215,29 +285,44 @@ def get_compile_spec(
     return spec_builder.build()
 
 
-def get_evaluator(model_name: str) -> GenericModelEvaluator:
-    if model_name not in evaluators:
-        return GenericModelEvaluator
-    else:
-        return evaluators[model_name]
-
-
 def evaluate_model(
     model_name: str,
     intermediates: str,
     model_fp32: torch.nn.Module,
     model_int8: torch.nn.Module,
     example_inputs: Tuple[torch.Tensor],
-):
-    evaluator = get_evaluator(model_name)
+    evaluator_name: str,
+    evaluator_config: str | None,
+) -> None:
+    evaluator = evaluators[evaluator_name]
 
     # Get the path of the TOSA flatbuffer that is dumped
    intermediates_path = Path(intermediates)
     tosa_paths = list(intermediates_path.glob("*.tosa"))
 
-    init_evaluator = evaluator(
-        model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
-    )
+    if evaluator.REQUIRES_CONFIG:
+        assert evaluator_config is not None
+
+        config_path = Path(evaluator_config)
+        with config_path.open() as f:
+            config = json.load(f)
+
+        if evaluator_name == "mv2":
+            init_evaluator = evaluator(
+                model_name,
+                model_fp32,
+                model_int8,
+                example_inputs,
+                str(tosa_paths[0]),
+                config["batch_size"],
+                config["validation_dataset_path"],
+            )
+        else:
+            raise RuntimeError(f"Unknown evaluator {evaluator_name}")
+    else:
+        init_evaluator = evaluator(
+            model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
+        )
 
     quant_metrics = init_evaluator.evaluate()
     output_json_path = intermediates_path / "quant_metrics.json"
@@ -289,11 +374,19 @@ def get_args():
     parser.add_argument(
         "-e",
         "--evaluate",
-        action="store_true",
         required=False,
-        default=False,
+        nargs="?",
+        const="generic",
+        choices=["generic", "mv2"],
         help="Flag for running evaluation of the model.",
     )
+    parser.add_argument(
+        "-c",
+        "--evaluate_config",
+        required=False,
+        default=None,
+        help="Provide path to evaluator config, if it is required.",
+    )
     parser.add_argument(
         "-q",
         "--quantize",
@@ -375,7 +468,9 @@ def get_args():
     # Quantize if required
     model_int8 = None
     if args.quantize:
-        model = quantize(model, example_inputs)
+        model = quantize(
+            model, args.model_name, example_inputs, args.evaluate, args.evaluate_config
+        )
         model_int8 = model
         # Wrap quantized model back into an exported_program
         exported_program = torch.export.export_for_training(model, example_inputs)
@@ -433,5 +528,11 @@ def get_args():
 
     if args.evaluate:
         evaluate_model(
-            args.model_name, args.intermediates, model_fp32, model_int8, example_inputs
+            args.model_name,
+            args.intermediates,
+            model_fp32,
+            model_int8,
+            example_inputs,
+            args.evaluate,
+            args.evaluate_config,
         )
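
Usage note (not part of the patch): the sketch below shows one way the new pieces fit together. The `MobileNetV2Evaluator` constructor arguments, the config keys (`batch_size`, `training_dataset_path`, `validation_dataset_path`), the `-e/--evaluate` choices and the `-c/--evaluate_config` flag all come from the diff above; the torchvision model construction, the dataset paths and the file names are illustrative assumptions only.

```python
# Minimal sketch, assuming an ImageFolder-style ImageNet layout on disk
# (which is what MobileNetV2Evaluator.__load_dataset expects).
import json

import torch
from torchvision.models import mobilenet_v2

from executorch.backends.arm.util.arm_model_evaluator import MobileNetV2Evaluator

# A config of the shape read via the new -c/--evaluate_config flag.
config = {
    "batch_size": 32,
    "training_dataset_path": "/data/imagenet/train",  # hypothetical path
    "validation_dataset_path": "/data/imagenet/val",  # hypothetical path
}
with open("mv2_evaluator_config.json", "w") as f:
    json.dump(config, f)

# Standalone use of the evaluator, outside the aot_arm_compiler.py flow.
fp32_model = mobilenet_v2(weights=None).eval()
int8_model = fp32_model  # placeholder: normally the convert_pt2e() output
example_input = (torch.randn(1, 3, 224, 224),)

evaluator = MobileNetV2Evaluator(
    "mv2",
    fp32_model,
    int8_model,
    example_input,
    None,  # tosa_output_path; the signature allows None, pass a .tosa file to also get the compression-ratio metric
    batch_size=config["batch_size"],
    validation_dataset_path=config["validation_dataset_path"],
)
# Returns {"name": "mv2", "metrics": {..., "accuracy": {"top-1": ..., "top-5": ...}}}
metrics = evaluator.evaluate()
```

With the same config file, the in-tree flow would be driven through aot_arm_compiler.py using `--quantize` together with `--evaluate mv2 --evaluate_config mv2_evaluator_config.json`; the remaining model and target flags follow the rest of the script and are not shown here.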