Commit 0601b7f
Arm backend: Enable int16x8 quantization on aot_arm_compiler (#15811)

### Summary
Adds an int16x8 target to aot_arm_compiler to enable accuracy testing in backends/arm/util/arm_model_evaluator.py.
- Renames int8 -> quant in model variable names.

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai

Signed-off-by: Saoirse Stewart <[email protected]>
1 parent da6306f commit 0601b7f
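
For orientation, here is a minimal sketch (not part of the commit) of how the new target string is intended to drive quantization config selection. It mirrors the `quantize()` change in `examples/arm/aot_arm_compiler.py` further down; the `compile_specs` argument is assumed to expose `tosa_spec.support_extension()` exactly as used in that diff.

```python
# Minimal sketch, not part of the commit: pick the operator config for the new
# "TOSA-1.0+INT+int16" target, falling back to the default int8x8 config.
from executorch.backends.arm.quantizer import (
    get_symmetric_a16w8_quantization_config,
    get_symmetric_quantization_config,
)


def select_operator_config(compile_specs, target: str):
    """Return the int16x8 config for the int16 target, int8x8 otherwise."""
    is_int16x8 = target == "TOSA-1.0+INT+int16"
    if is_int16x8:
        if not compile_specs.tosa_spec.support_extension("int16"):
            raise ValueError(
                f"TOSA spec {compile_specs.tosa_spec} doesn't support int16"
            )
        # 16-bit activations with 8-bit per-channel weights
        return get_symmetric_a16w8_quantization_config(is_per_channel=True)
    # Default: int8 activations / int8 weights, per-channel
    return get_symmetric_quantization_config(is_per_channel=True)
```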

File tree (2 files changed, +61 −41 lines):
- backends/arm/util/arm_model_evaluator.py
- examples/arm/aot_arm_compiler.py

backends/arm/util/arm_model_evaluator.py

Lines changed: 25 additions & 25 deletions
@@ -167,14 +167,14 @@ def __init__(
         self,
         model_name: str,
         fp32_model: torch.nn.Module,
-        int8_model: torch.nn.Module,
+        quant_model: torch.nn.Module,
         example_input: Tuple[torch.Tensor],
         tosa_output_path: Optional[str],
     ) -> None:
         self.model_name = model_name

         self.fp32_model = fp32_model
-        self.int8_model = int8_model
+        self.quant_model = quant_model
         self.example_input = example_input

         if tosa_output_path:
@@ -192,12 +192,12 @@ def get_model_error(self) -> defaultdict:
             mean_absolute_error
         """
         fp32_outputs, _ = tree_flatten(self.fp32_model(*self.example_input))
-        int8_outputs, _ = tree_flatten(self.int8_model(*self.example_input))
+        quant_outputs, _ = tree_flatten(self.quant_model(*self.example_input))

         model_error_dict = defaultdict(list)

-        for fp32_output, int8_output in zip(fp32_outputs, int8_outputs):
-            difference = fp32_output - int8_output
+        for fp32_output, quant_output in zip(fp32_outputs, quant_outputs):
+            difference = fp32_output - quant_output
             # Avoid divide by zero: elements where fp32 == 0 produce 0% contribution
             percentage_error = torch.where(
                 fp32_output != 0,
@@ -252,14 +252,14 @@ def __init__(
         self,
         model_name: str,
         fp32_model: Module,
-        int8_model: Module,
+        quant_model: Module,
         example_input: Tuple[torch.Tensor],
         tosa_output_path: str | None,
         batch_size: int,
         validation_dataset_path: str,
     ) -> None:
         super().__init__(
-            model_name, fp32_model, int8_model, example_input, tosa_output_path
+            model_name, fp32_model, quant_model, example_input, tosa_output_path
         )

         self.__batch_size = batch_size
@@ -279,7 +279,7 @@ def from_config(
         cls,
         model_name: str,
         fp32_model: Module,
-        int8_model: Module,
+        quant_model: Module,
         example_input: Tuple[torch.Tensor],
         tosa_output_path: str | None,
         config: dict[str, Any],
@@ -291,7 +291,7 @@ def from_config(
         return cls(
             model_name,
             fp32_model,
-            int8_model,
+            quant_model,
             example_input,
             tosa_output_path,
             batch_size=config["batch_size"],
@@ -302,10 +302,9 @@ def evaluate(self) -> dict[str, Any]:
         # Load dataset and compute top-1 / top-5
         dataset = MobileNetV2Evaluator.__load_dataset(self.__validation_set_path)
         top1_correct, top5_correct = GenericModelEvaluator.evaluate_topk(
-            self.int8_model, dataset, self.__batch_size, topk=5
+            self.quant_model, dataset, self.__batch_size, topk=5
         )
         output = super().evaluate()
-
         output["metrics"]["accuracy"] = {"top-1": top1_correct, "top-5": top5_correct}
         return output

@@ -317,14 +316,14 @@ def __init__(
         self,
         model_name: str,
         fp32_model: Module,
-        int8_model: Module,
+        quant_model: Module,
         example_input: Tuple[torch.Tensor],
         tosa_output_path: str | None,
         batch_size: int,
         validation_dataset_path: str,
     ) -> None:
         super().__init__(
-            model_name, fp32_model, int8_model, example_input, tosa_output_path
+            model_name, fp32_model, quant_model, example_input, tosa_output_path
         )
         self.__batch_size = batch_size
         self.__validation_set_path = validation_dataset_path
@@ -343,7 +342,7 @@ def from_config(
         cls,
         model_name: str,
         fp32_model: Module,
-        int8_model: Module,
+        quant_model: Module,
         example_input: Tuple[torch.Tensor],
         tosa_output_path: str | None,
         config: dict[str, Any],
@@ -355,7 +354,7 @@ def from_config(
         return cls(
             model_name,
             fp32_model,
-            int8_model,
+            quant_model,
             example_input,
             tosa_output_path,
             batch_size=config["batch_size"],
@@ -366,7 +365,7 @@ def evaluate(self) -> dict[str, Any]:
         # Load dataset and compute top-1 / top-5
         dataset = DeiTTinyEvaluator.__load_dataset(self.__validation_set_path)
         top1, top5 = GenericModelEvaluator.evaluate_topk(
-            self.int8_model, dataset, self.__batch_size, topk=5
+            self.quant_model, dataset, self.__batch_size, topk=5
         )
         output = super().evaluate()
         output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5}
@@ -380,14 +379,14 @@ def __init__(
         self,
         model_name: str,
         fp32_model: Module,
-        int8_model: Module,
+        quant_model: Module,
         example_input: Tuple[torch.Tensor],
         tosa_output_path: str | None,
         batch_size: int,
         validation_dataset_path: str,
     ) -> None:
         super().__init__(
-            model_name, fp32_model, int8_model, example_input, tosa_output_path
+            model_name, fp32_model, quant_model, example_input, tosa_output_path
         )
         self.__batch_size = batch_size
         self.__validation_set_path = validation_dataset_path
@@ -406,15 +405,15 @@ def from_config(
         cls,
         model_name: str,
         fp32_model: Module,
-        int8_model: Module,
+        quant_model: Module,
         example_input: Tuple[torch.Tensor],
         tosa_output_path: str | None,
         config: dict[str, Any],
     ) -> "ResNet18Evaluator":
         return cls(
             model_name,
             fp32_model,
-            int8_model,
+            quant_model,
             example_input,
             tosa_output_path,
             batch_size=config["batch_size"],
@@ -424,7 +423,7 @@ def from_config(
     def evaluate(self) -> dict[str, Any]:
         dataset = ResNet18Evaluator.__load_dataset(self.__validation_set_path)
         top1, top5 = GenericModelEvaluator.evaluate_topk(
-            self.int8_model, dataset, self.__batch_size, topk=5
+            self.quant_model, dataset, self.__batch_size, topk=5
         )
         output = super().evaluate()
         output["metrics"]["accuracy"] = {"top-1": top1, "top-5": top5}
@@ -463,8 +462,9 @@ def evaluator_calibration_data(
 def evaluate_model(
     model_name: str,
     intermediates: str,
+    target: str,
     model_fp32: torch.nn.Module,
-    model_int8: torch.nn.Module,
+    model_quant: torch.nn.Module,
     example_inputs: Tuple[torch.Tensor],
     evaluator_name: str,
     evaluator_config: str | None,
@@ -486,7 +486,7 @@ def evaluate_model(
         init_evaluator = factory(
             model_name,
             model_fp32,
-            model_int8,
+            model_quant,
             example_inputs,
             str(tosa_paths[0]),
             config,
@@ -497,11 +497,11 @@ def evaluate_model(
         )
     else:
         init_evaluator = evaluator(
-            model_name, model_fp32, model_int8, example_inputs, str(tosa_paths[0])
+            model_name, model_fp32, model_quant, example_inputs, str(tosa_paths[0])
        )

     quant_metrics = init_evaluator.evaluate()
-    output_json_path = intermediates_path / "quant_metrics.json"
+    output_json_path = intermediates_path / f"{target}-quant_metrics.json"

     with output_json_path.open("w") as json_file:
         json.dump(quant_metrics, json_file)
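
With `target` threaded through `evaluate_model`, metrics from different targets now land in separate files instead of all overwriting `quant_metrics.json`. A small illustration (the intermediates directory name here is an assumption):

```python
# Illustration only: how the per-target metrics filename is built.
from pathlib import Path

intermediates_path = Path("intermediates")  # assumed output directory
target = "TOSA-1.0+INT+int16"               # the new target string
output_json_path = intermediates_path / f"{target}-quant_metrics.json"
print(output_json_path)  # intermediates/TOSA-1.0+INT+int16-quant_metrics.json
```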

examples/arm/aot_arm_compiler.py

Lines changed: 36 additions & 16 deletions
@@ -19,7 +19,10 @@
 from examples.devtools.scripts.export_bundled_program import save_bundled_program
 from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
 from executorch.backends.arm.ethosu import EthosUCompileSpec
-from executorch.backends.arm.quantizer import get_symmetric_quantization_config
+from executorch.backends.arm.quantizer import (
+    get_symmetric_a16w8_quantization_config,
+    get_symmetric_quantization_config,
+)
 from executorch.backends.arm.tosa import TosaSpecification
 from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
 from executorch.backends.arm.util._factory import create_partitioner, create_quantizer
@@ -228,6 +231,7 @@ def quantize(
     example_inputs: Tuple[torch.Tensor],
     evaluator_name: str | None,
     evaluator_config: Dict[str, Any] | None,
+    is_int16x8: bool = False,
 ) -> GraphModule:
     """This is the official recommended flow for quantization in pytorch 2.0
     export.
@@ -238,7 +242,18 @@

     quantizer = create_quantizer(compile_specs)

-    operator_config = get_symmetric_quantization_config()
+    if is_int16x8:
+        if compile_specs.tosa_spec.support_extension("int16"):
+            operator_config = get_symmetric_a16w8_quantization_config(
+                is_per_channel=True
+            )
+        else:
+            raise ValueError(
+                f"Context TOSA spec {compile_specs.tosa_spec} doesn't support int16"
+            )
+    else:
+        operator_config = get_symmetric_quantization_config(is_per_channel=True)
+
     quantizer.set_global(operator_config)
     m = prepare_pt2e(model, quantizer)

@@ -356,6 +371,7 @@ def forward(self, x):
     "vgf",
     "TOSA-1.0+INT",
     "TOSA-1.0+FP",
+    "TOSA-1.0+INT+int16",
 ]


@@ -681,20 +697,23 @@ def quantize_model(
     example_inputs: Tuple[torch.Tensor],
     compile_spec,
 ) -> Tuple[GraphModule, ExportedProgram]:
-    model_int8 = quantize(
+
+    is_int16x8 = True if args.target == "TOSA-1.0+INT+int16" else False
+    model_quant = quantize(
         model,
         args.model_name,
         compile_spec,
         example_inputs,
         args.evaluate,
         args.evaluate_config,
+        is_int16x8,
     )
     # Wrap quantized model back into an exported_program
     exported_program = torch.export.export(
-        model_int8, example_inputs, strict=args.strict_export
+        model_quant, example_inputs, strict=args.strict_export
     )

-    return model_int8, exported_program
+    return model_quant, exported_program


 def to_edge_TOSA_delegate(
@@ -715,9 +734,9 @@ def to_edge_TOSA_delegate(
         args.enable_debug_mode,
     )

-    model_int8 = None
+    model_quant = None
     if args.quantize:
-        model_int8, exported_program = quantize_model(
+        model_quant, exported_program = quantize_model(
             args, model, example_inputs, compile_spec
         )

@@ -731,7 +750,7 @@ def to_edge_TOSA_delegate(
         ),
     )

-    return model_int8, edge
+    return model_quant, edge


 def to_edge_no_delegate(
@@ -740,7 +759,7 @@ def to_edge_no_delegate(
     model: GraphModule,
     example_inputs: Tuple[torch.Tensor],
 ):
-    model_int8 = None
+    model_quant = None
     if args.quantize:
         # As we can target multiple output encodings, one must
         # be specified.
@@ -756,7 +775,7 @@ def to_edge_no_delegate(
         model, exported_program = quantize_model(
             args, model, example_inputs, compile_spec
         )
-        model_int8 = model
+        model_quant = model

     edge = to_edge_transform_and_lower(
         exported_program,
@@ -765,7 +784,7 @@ def to_edge_no_delegate(
         ),
     )

-    return model_int8, edge
+    return model_quant, edge


 def transform_for_cortex_m_backend(edge_program_manager, args):
@@ -818,13 +837,13 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
     )

     # Quantize if required
-    model_int8 = None
+    model_quant = None
     if args.delegate:
-        model_int8, edge = to_edge_TOSA_delegate(
+        model_quant, edge = to_edge_TOSA_delegate(
             exported_program, args, model, example_inputs
         )
     else:
-        model_int8, edge = to_edge_no_delegate(
+        model_quant, edge = to_edge_no_delegate(
             exported_program, args, model, example_inputs
         )

@@ -884,7 +903,7 @@ def transform_for_cortex_m_backend(edge_program_manager, args):

     if args.bundleio:
         # Realize the quantization impact on numerics when generating reference output
-        reference_model = original_model if not model_int8 else model_int8
+        reference_model = original_model if not model_quant else model_quant
         save_bpte_program(exec_prog, reference_model, output_file_name)
         print(f"Bundle PTE file saved as {output_file_name}")
     else:
@@ -895,8 +914,9 @@ def transform_for_cortex_m_backend(edge_program_manager, args):
         evaluate_model(
             args.model_name,
             args.intermediates,
+            args.target,
             model_fp32,
-            model_int8,
+            model_quant,
             example_inputs,
             args.evaluate,
             args.evaluate_config,
