
Commit 3784f06

tom-armfreddan80 authored and committed

Add MobileNetV2Evaluator to Arm Backend

* Add MobileNetV2Evaluator, which can measure top-1 and top-5 accuracy of MobileNet V2
* Extend aot_arm_compiler.py to accept more Evaluators
* Improve handling of calibration datasets in aot_arm_compiler.py
* Improve typing

Signed-off-by: Tom Allsop <[email protected]>
Change-Id: I50c52e7d97dc38da4ae8f09b258e61c472415ca1
1 parent 8526d0a commit 3784f06
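
The aot_arm_compiler.py changes are not included in this excerpt, so the following is a hypothetical sketch of how evaluator selection could use the `REQUIRES_CONFIG` flag introduced in arm_model_evaluator.py below; the registry and `build_evaluator` names are illustrative, not from the commit.

```python
# Hypothetical sketch only: the aot_arm_compiler.py diff is not shown in this
# excerpt. This illustrates the kind of evaluator registry that the
# REQUIRES_CONFIG flag (added below) makes possible.
from typing import Any

from executorch.backends.arm.util.arm_model_evaluator import (
    GenericModelEvaluator,
    MobileNetV2Evaluator,
)

# Illustrative registry mapping CLI names to evaluator classes.
evaluators = {
    "generic": GenericModelEvaluator,
    "mv2": MobileNetV2Evaluator,
}


def build_evaluator(name: str, config: dict[str, Any] | None, *args):
    cls = evaluators[name]
    if cls.REQUIRES_CONFIG and config is None:
        raise ValueError(f"Evaluator {name} requires a config.")
    if cls.REQUIRES_CONFIG:
        # e.g. batch_size and validation_dataset_path for MobileNetV2Evaluator
        return cls(*args, **config)
    return cls(*args)
```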

File tree

3 files changed: +235 −37 lines changed

backends/arm/test/misc/test_model_evaluator.py

Lines changed: 4 additions & 11 deletions

```diff
@@ -4,17 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import random
 import tempfile
 import unittest
 
 import torch
 from executorch.backends.arm.util.arm_model_evaluator import GenericModelEvaluator
 
-random.seed(0)
-
 # Create an input that is hard to compress
-COMPRESSION_RATIO_TEST = bytearray(random.getrandbits(8) for _ in range(1000000))
+COMPRESSION_RATIO_TEST = torch.rand([1024, 1024])
 
 
 def mocked_model_1(input: torch.Tensor) -> torch.Tensor:
@@ -47,20 +44,16 @@ def test_get_model_error(self):
 
     def test_get_compression_ratio(self):
         with tempfile.NamedTemporaryFile(delete=True) as temp_bin:
-            temp_bin.write(COMPRESSION_RATIO_TEST)
-
-            # As the size of the file is quite small we need to call flush()
-            temp_bin.flush()
-            temp_bin_name = temp_bin.name
+            torch.save(COMPRESSION_RATIO_TEST, temp_bin)
 
             example_input = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
             evaluator = GenericModelEvaluator(
                 "dummy_model",
                 mocked_model_1,
                 mocked_model_2,
                 example_input,
-                temp_bin_name,
+                temp_bin.name,
             )
 
             ratio = evaluator.get_compression_ratio()
-            self.assertAlmostEqual(ratio, 1.0, places=2)
+            self.assertAlmostEqual(ratio, 1.1, places=1)
```
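
The updated test expects a compression ratio of roughly 1.1 for random float data. `get_compression_ratio` itself is not part of this diff; below is a minimal sketch, assuming a zip-based measurement (zipfile is imported by arm_model_evaluator.py), of how such a ratio can be computed. The `compression_ratio` helper here is illustrative, not the committed implementation.

```python
# Minimal sketch of a zip-based compression ratio: uncompressed size divided
# by compressed size. Random float tensors compress poorly, so the ratio for
# the test input lands near 1.1.
import os
import tempfile
import zipfile

import torch


def compression_ratio(path: str) -> float:
    """Return uncompressed size / deflate-compressed size for the file at path."""
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=True) as temp_zip:
        with zipfile.ZipFile(temp_zip.name, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.write(path)
        return os.path.getsize(path) / os.path.getsize(temp_zip.name)


if __name__ == "__main__":
    with tempfile.NamedTemporaryFile(delete=True) as f:
        torch.save(torch.rand([1024, 1024]), f)
        f.flush()  # make sure buffered bytes reach disk before measuring
        print(compression_ratio(f.name))  # ~1.1 for incompressible data
```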

backends/arm/util/arm_model_evaluator.py

Lines changed: 106 additions & 2 deletions

```diff
@@ -4,13 +4,25 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import os
+import random
 import tempfile
 import zipfile
+
 from collections import defaultdict
-from typing import Optional, Tuple
+from pathlib import Path
+from typing import Any, Optional, Tuple
 
 import torch
+from torch.nn.modules import Module
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+
+# Logger for outputting progress for longer running evaluation
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 
 def flatten_args(args) -> tuple | list:
@@ -28,6 +40,8 @@ def flatten_args(args) -> tuple | list:
 
 
 class GenericModelEvaluator:
+    REQUIRES_CONFIG = False
+
     def __init__(
         self,
         model_name: str,
@@ -90,7 +104,7 @@ def get_compression_ratio(self) -> float:
 
         return compression_ratio
 
-    def evaluate(self) -> dict[any]:
+    def evaluate(self) -> dict[Any]:
         model_error_dict = self.get_model_error()
 
         output_metrics = {"name": self.model_name, "metrics": dict(model_error_dict)}
@@ -103,3 +117,93 @@ def evaluate(self) -> dict[any]:
         ] = self.get_compression_ratio()
 
         return output_metrics
+
+
+class MobileNetV2Evaluator(GenericModelEvaluator):
+    REQUIRES_CONFIG = True
+
+    def __init__(
+        self,
+        model_name: str,
+        fp32_model: Module,
+        int8_model: Module,
+        example_input: Tuple[torch.Tensor],
+        tosa_output_path: str | None,
+        batch_size: int,
+        validation_dataset_path: str,
+    ) -> None:
+        super().__init__(
+            model_name, fp32_model, int8_model, example_input, tosa_output_path
+        )
+
+        self.__batch_size = batch_size
+        self.__validation_set_path = validation_dataset_path
+
+    @staticmethod
+    def __load_dataset(directory: str) -> datasets.ImageFolder:
+        directory_path = Path(directory)
+        if not directory_path.exists():
+            raise FileNotFoundError(f"Directory: {directory} does not exist.")
+
+        transform = transforms.Compose(
+            [
+                transforms.Resize(256),
+                transforms.CenterCrop(224),
+                transforms.ToTensor(),
+                transforms.Normalize(
+                    mean=[0.484, 0.454, 0.403], std=[0.225, 0.220, 0.220]
+                ),
+            ]
+        )
+        return datasets.ImageFolder(directory_path, transform=transform)
+
+    @staticmethod
+    def get_calibrator(training_dataset_path: str) -> DataLoader:
+        dataset = MobileNetV2Evaluator.__load_dataset(training_dataset_path)
+        rand_indices = random.sample(range(len(dataset)), k=1000)
+
+        # Return a subset of the dataset to be used for calibration
+        return torch.utils.data.DataLoader(
+            torch.utils.data.Subset(dataset, rand_indices),
+            batch_size=1,
+            shuffle=False,
+        )
+
+    def __evaluate_mobilenet(self) -> Tuple[float, float]:
+        dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path)
+        loaded_dataset = DataLoader(
+            dataset,
+            batch_size=self.__batch_size,
+            shuffle=False,
+        )
+
+        top1_correct = 0
+        top5_correct = 0
+
+        for i, (image, target) in enumerate(loaded_dataset):
+            prediction = self.int8_model(image)
+            top1_prediction = torch.topk(prediction, k=1, dim=1).indices
+            top5_prediction = torch.topk(prediction, k=5, dim=1).indices
+
+            top1_correct += (top1_prediction == target.view(-1, 1)).sum().item()
+            top5_correct += (top5_prediction == target.view(-1, 1)).sum().item()
+
+            logger.info("Iteration: {}".format((i + 1) * self.__batch_size))
+            logger.info(
+                "Top 1: {}".format(top1_correct / ((i + 1) * self.__batch_size))
+            )
+            logger.info(
+                "Top 5: {}".format(top5_correct / ((i + 1) * self.__batch_size))
+            )
+
+        top1_accuracy = top1_correct / len(dataset)
+        top5_accuracy = top5_correct / len(dataset)
+
+        return top1_accuracy, top5_accuracy
+
+    def evaluate(self) -> dict[str, Any]:
+        top1_correct, top5_correct = self.__evaluate_mobilenet()
+        output = super().evaluate()
+
+        output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct}
+        return output
```
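
Taken together, a possible end-to-end use of the new evaluator, as a sketch under stated assumptions: the dataset paths are placeholders for ImageNet-style ImageFolder trees, and the stand-in models below substitute for the real fp32/quantized pair that the compile flow would produce.

```python
# Usage sketch under stated assumptions: dataset paths and models are
# placeholders, not part of the commit.
import torch
from torchvision.models import mobilenet_v2

from executorch.backends.arm.util.arm_model_evaluator import MobileNetV2Evaluator

fp32_model = mobilenet_v2(weights="DEFAULT").eval()
int8_model = fp32_model  # stand-in; pass the quantized module in practice
example_input = (torch.rand(1, 3, 224, 224),)

# Calibration: a DataLoader over 1000 randomly sampled training images.
calibrator = MobileNetV2Evaluator.get_calibrator("datasets/imagenet/train")
for image, _ in calibrator:
    fp32_model(image)  # feed the model being prepared for quantization

# Evaluation: top-1/top-5 accuracy on top of the generic metrics.
evaluator = MobileNetV2Evaluator(
    "mv2",
    fp32_model,
    int8_model,
    example_input,
    tosa_output_path=None,
    batch_size=64,
    validation_dataset_path="datasets/imagenet/val",
)
metrics = evaluator.evaluate()
print(metrics["metrics"]["accuracy"])  # {"top-1": ..., "top-5": ...}
```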
