Commit b13a77a

Arm backend: Small refactoring of aot_arm_compiler (#14583)
* Move evaluation logic out of aot_arm_compiler.

* Remove example models. It no longer makes sense to keep example models in the aot_arm_compiler script: they are too simple, and there are better ways to provide examples. Add and softmax can be replaced with examples/models tests; add2 and add3 are deprecated. The q-models can't be removed yet, since the new testing is not in place, but they should be as soon as it is.

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218

---------

Signed-off-by: Erik Lundell <[email protected]>
1 parent a747e4d commit b13a77a

File tree

3 files changed: +79 additions, -149 deletions

backends/arm/util/arm_model_evaluator.py

Lines changed: 76 additions & 1 deletion
@@ -6,6 +6,7 @@
 
 # pyre-unsafe
 
+import json
 import logging
 import os
 import random
@@ -14,7 +15,7 @@
 
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, Optional, Tuple
+from typing import Any, cast, Optional, Tuple
 
 import torch
 from torch.nn.modules import Module
@@ -197,3 +198,77 @@ def evaluate(self) -> dict[str, Any]:
 
         output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct}
         return output
+
+
+evaluators: dict[str, type[GenericModelEvaluator]] = {
+    "generic": GenericModelEvaluator,
+    "mv2": MobileNetV2Evaluator,
+}
+
+
+def evaluator_calibration_data(
+    evaluator_name: str,
+    evaluator_config: str | None,
+):
+    evaluator = evaluators[evaluator_name]
+
+    if hasattr(evaluator, "get_calibrator"):
+        assert evaluator_config is not None
+
+        config_path = Path(evaluator_config)
+        with config_path.open() as f:
+            config = json.load(f)
+
+        if evaluator is MobileNetV2Evaluator:
+            return evaluator.get_calibrator(
+                training_dataset_path=config["training_dataset_path"]
+            )
+        else:
+            raise RuntimeError(f"Unknown evaluator: {evaluator_name}")
+
+
+def evaluate_model(
+    model_name: str,
+    intermediates: str,
+    model_fp32: torch.nn.Module,
+    model_int8: torch.nn.Module,
+    example_inputs: Tuple[torch.Tensor],
+    evaluator_name: str,
+    evaluator_config: str | None,
+) -> None:
+    evaluator = evaluators[evaluator_name]
+
+    # Get the path of the TOSA flatbuffer that is dumped
+    intermediates_path = Path(intermediates)
+    tosa_paths = list(intermediates_path.glob("*.tosa"))
+
+    if evaluator.REQUIRES_CONFIG:
+        assert evaluator_config is not None
+
+        config_path = Path(evaluator_config)
+        with config_path.open() as f:
+            config = json.load(f)
+
+        if evaluator == MobileNetV2Evaluator:
+            mv2_evaluator = cast(type[MobileNetV2Evaluator], evaluator)
+            init_evaluator: GenericModelEvaluator = mv2_evaluator(
+                model_name,
+                model_fp32,
+                model_int8,
+                example_inputs,
+                str(tosa_paths[0]),
+                batch_size=config["batch_size"],
+                validation_dataset_path=config["validation_dataset_path"],
+            )
+        else:
+            raise RuntimeError(f"Unknown evaluator {evaluator_name}")
+    else:
+        init_evaluator = evaluator(
+            model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
+        )
+
+    quant_metrics = init_evaluator.evaluate()
+    output_json_path = intermediates_path / "quant_metrics.json"
+
+    with output_json_path.open("w") as json_file:
+        json.dump(quant_metrics, json_file)
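
For orientation, a minimal sketch of how the relocated helpers might now be driven from a caller, under stated assumptions: TinyAdd, the intermediates directory name, and the pre-dumped .tosa file are hypothetical placeholders, and the "generic" evaluator is assumed to define no get_calibrator() and to have REQUIRES_CONFIG unset, which the code above is not shown confirming.

# Hypothetical usage sketch; not part of the commit.
from pathlib import Path

import torch

from executorch.backends.arm.util.arm_model_evaluator import (
    evaluate_model,
    evaluator_calibration_data,
)


class TinyAdd(torch.nn.Module):  # stand-in model for illustration
    def forward(self, x):
        return x + x


fp32 = TinyAdd()
int8 = TinyAdd()  # in practice: the quantized counterpart of fp32
example_inputs = (torch.randn(4),)

# Expected to be None for the generic evaluator, since the hasattr()
# check above only triggers for evaluators with get_calibrator().
calibration = evaluator_calibration_data("generic", None)
print(calibration)

# evaluate_model() globs <intermediates>/*.tosa, so the directory must
# already contain the TOSA flatbuffer dumped during compilation; it then
# writes quant_metrics.json into that same directory.
evaluate_model(
    model_name="tiny_add",
    intermediates="arm_intermediates",  # hypothetical dump directory
    model_fp32=fp32,
    model_int8=int8,
    example_inputs=example_inputs,
    evaluator_name="generic",
    evaluator_config=None,
)
print(Path("arm_intermediates/quant_metrics.json").read_text())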

examples/arm/aot_arm_compiler.py

Lines changed: 3 additions & 146 deletions
@@ -9,7 +9,6 @@
 
 import argparse
 import copy
-import json
 import logging
 import os
 
@@ -31,8 +30,8 @@
 from executorch.backends.arm.tosa.partitioner import TOSAPartitioner
 
 from executorch.backends.arm.util.arm_model_evaluator import (
-    GenericModelEvaluator,
-    MobileNetV2Evaluator,
+    evaluate_model,
+    evaluator_calibration_data,
 )
 
 from executorch.backends.arm.vgf import VgfCompileSpec, VgfPartitioner
@@ -188,46 +187,6 @@ def quantize(
     return m
 
 
-# Simple example models
-class AddModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x):
-        return x + x
-
-    example_input = (torch.ones(5, dtype=torch.int32),)
-    can_delegate = True
-
-
-class AddModule2(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return x + y
-
-    example_input = (
-        torch.ones(5, dtype=torch.int32),
-        torch.ones(5, dtype=torch.int32),
-    )
-    can_delegate = True
-
-
-class AddModule3(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, x, y):
-        return (x + y, x + x)
-
-    example_input = (
-        torch.ones(5, dtype=torch.int32),
-        torch.ones(5, dtype=torch.int32),
-    )
-    can_delegate = True
-
-
 class QuantAddTest(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -276,27 +235,6 @@ def forward(self, w, x, y, z):
     can_delegate = True  # when quantized
 
 
-class SoftmaxModule(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.softmax = torch.nn.Softmax(dim=0)
-
-    def forward(self, x):
-        z = self.softmax(x)
-        return z
-
-    example_input = (torch.ones(2, 2),)
-    can_delegate = True
-
-
-class MultipleOutputsModule(torch.nn.Module):
-    def forward(self, x: torch.Tensor, y: torch.Tensor):
-        return (x * y, x.sum(dim=-1, keepdim=True))
-
-    example_input = (torch.randn(10, 4, 5), torch.randn(10, 4, 5))
-    can_delegate = True
-
-
 class QuantLinearTest(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -311,29 +249,15 @@ def forward(self, x):
 
 
 models = {
-    "add": AddModule,
-    "add2": AddModule2,
-    "add3": AddModule3,
     "qadd": QuantAddTest,
     "qadd2": QuantAddTest2,
     "qops": QuantOpTest,
-    "softmax": SoftmaxModule,
-    "MultipleOutputsModule": MultipleOutputsModule,
     # TODO: Remove this from here, once we have dedicated MCU test pipeline ready. This is an interim solution.
     # See https://github.com/pytorch/executorch/discussions/13944
     "qlinear": QuantLinearTest,
 }
 
 calibration_data = {
-    "add": (torch.randn(1, 5),),
-    "add2": (
-        torch.randn(1, 5),
-        torch.randn(1, 5),
-    ),
-    "add3": (
-        torch.randn(32, 5),
-        torch.randn(32, 5),
-    ),
     "qadd": (torch.randn(32, 2, 1),),
     "qadd2": (
         torch.randn(32, 2, 1),
@@ -345,13 +269,6 @@ def forward(self, x):
         torch.randn(32, 2, 1) * -0.000001,
         torch.randn(32, 2, 1) * 1000,
     ),
-    "softmax": (torch.randn(32, 2, 2),),
-    "qlinear": (torch.randn(37, 61),),
-}
-
-evaluators = {
-    "generic": GenericModelEvaluator,
-    "mv2": MobileNetV2Evaluator,
 }
 
 targets = [
@@ -378,21 +295,7 @@ def get_calibration_data(
 ):
     # Firstly, if the model is being evaluated, take the evaluators calibration function if it has one
     if evaluator_name is not None:
-        evaluator = evaluators[evaluator_name]
-
-        if hasattr(evaluator, "get_calibrator"):
-            assert evaluator_config is not None
-
-            config_path = Path(evaluator_config)
-            with config_path.open() as f:
-                config = json.load(f)
-
-            if evaluator_name == "mv2":
-                return evaluator.get_calibrator(
-                    training_dataset_path=config["training_dataset_path"]
-                )
-            else:
-                raise RuntimeError(f"Unknown evaluator: {evaluator_name}")
+        return evaluator_calibration_data(evaluator_name, evaluator_config)
 
     # If the model is in the calibration_data dictionary, get the data from there
     # This is used for the simple model examples provided
@@ -446,52 +349,6 @@ def get_compile_spec(
     return compile_spec
 
 
-def evaluate_model(
-    model_name: str,
-    intermediates: str,
-    model_fp32: torch.nn.Module,
-    model_int8: torch.nn.Module,
-    example_inputs: Tuple[torch.Tensor],
-    evaluator_name: str,
-    evaluator_config: str | None,
-) -> None:
-    evaluator = evaluators[evaluator_name]
-
-    # Get the path of the TOSA flatbuffer that is dumped
-    intermediates_path = Path(intermediates)
-    tosa_paths = list(intermediates_path.glob("*.tosa"))
-
-    if evaluator.REQUIRES_CONFIG:
-        assert evaluator_config is not None
-
-        config_path = Path(evaluator_config)
-        with config_path.open() as f:
-            config = json.load(f)
-
-        if evaluator_name == "mv2":
-            init_evaluator = evaluator(
-                model_name,
-                model_fp32,
-                model_int8,
-                example_inputs,
-                str(tosa_paths[0]),
-                config["batch_size"],
-                config["validation_dataset_path"],
-            )
-        else:
-            raise RuntimeError(f"Unknown evaluator {evaluator_name}")
-    else:
-        init_evaluator = evaluator(
-            model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
-        )
-
-    quant_metrics = init_evaluator.evaluate()
-    output_json_path = intermediates_path / "quant_metrics.json"
-
-    with output_json_path.open("w") as json_file:
-        json.dump(quant_metrics, json_file)
-
-
 def dump_delegation_info(edge, intermediate_files_folder: Optional[str] = None):
     graph_module = edge.exported_program().graph_module
     delegation_info = get_delegation_info(graph_module)
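
The evaluator_config argument threaded through both helpers is a path to a JSON file. Based on the keys read in this commit, a config for the mv2 evaluator might be produced as sketched below; the dataset paths and the file name are hypothetical.

# Hypothetical mv2 evaluator config sketch; key names come from the diff
# above, the values are placeholders.
import json
from pathlib import Path

from executorch.backends.arm.util.arm_model_evaluator import (
    evaluator_calibration_data,
)

config = {
    "training_dataset_path": "/data/imagenet/train",  # read by get_calibrator()
    "batch_size": 32,  # read by evaluate_model()
    "validation_dataset_path": "/data/imagenet/val",  # read by evaluate_model()
}
Path("mv2_config.json").write_text(json.dumps(config, indent=2))

# Resolve calibration data through the new helper; this presumably only
# succeeds if a training dataset actually exists at the path above.
calibrator = evaluator_calibration_data("mv2", "mv2_config.json")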

examples/arm/run.sh

Lines changed: 0 additions & 2 deletions
@@ -225,7 +225,6 @@ if [[ -z "$model_name" ]]; then
     test_model=(
         "softmax" # 0
        "add" # 1
-        "add3" # 2
        "qadd" # 3
        "qadd2" # 4
        "qops" # 5
@@ -234,7 +233,6 @@ if [[ -z "$model_name" ]]; then
    model_compiler_flags=(
        "" # 0 softmax
        "--delegate" # 1 add
-        "--delegate" # 2 add3
        "--delegate --quantize" # 3 qadd
        "--delegate --quantize" # 4 qadd2
        "--delegate --quantize" # 5 qops
