
Commit 5055796

NXP backend: Move calibration inputs creation outside to_quantized_edge_program()
Enables using different data sources for calibration inputs.
1 parent: 3067e98

File tree: 1 file changed

backends/nxp/tests/executorch_pipeline.py

Lines changed: 53 additions & 31 deletions
@@ -3,6 +3,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from dataclasses import dataclass
+from typing import Callable
+
+import numpy as np
 import torch
 
 from executorch import exir
@@ -23,7 +27,13 @@
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 
-def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
+@dataclass
+class ModelInputSpec:
+    shape: tuple[int, ...]
+    dtype: torch.dtype = torch.float32
+
+
+def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor, ...]]):
     quantizer = NeutronQuantizer()
 
     m = prepare_pt2e(model, quantizer)
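
The ModelInputSpec dataclass introduced above describes a single model input by its shape and dtype. A minimal sketch of how a two-input model could be specified (the names and shapes are illustrative, not part of the commit):

import torch

# Hypothetical specs for a model taking an image and a feature vector.
image_spec = ModelInputSpec(shape=(1, 3, 128, 128))  # dtype defaults to torch.float32
features_spec = ModelInputSpec(shape=(1, 64), dtype=torch.float32)
input_spec = (image_spec, features_spec)

Since the calibration helper below draws data with torch.randn, non-floating-point dtypes would need a custom sampling function.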
@@ -34,34 +44,51 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
     return m
 
 
-def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):
-    # TODO: Replace with something more robust.
-    return (
-        (torch.randn(input_shapes),)
-        if type(input_shapes) is tuple
-        else tuple(torch.randn(input_shape) for input_shape in input_shapes)
-    )
+def get_random_calibration_inputs(
+    input_spec: tuple[ModelInputSpec, ...]
+) -> list[tuple[torch.Tensor, ...]]:
+    return [
+        tuple([torch.randn(spec.shape, dtype=spec.dtype) for spec in input_spec])
+        for _ in range(4)
+    ]
+
+
+def to_model_input_spec(
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
+) -> tuple[ModelInputSpec, ...]:
+
+    if isinstance(input_spec, tuple) and all(
+        isinstance(spec, ModelInputSpec) for spec in input_spec
+    ):
+        return input_spec
+
+    elif isinstance(input_spec, tuple) and all(
+        isinstance(spec, int) for spec in input_spec
+    ):
+        return (ModelInputSpec(input_spec),)
+
+    elif isinstance(input_spec, list) and all(
+        isinstance(input_shape, tuple) for input_shape in input_spec
+    ):
+        return tuple([ModelInputSpec(spec) for spec in input_spec])
+    else:
+        raise TypeError(f"Unsupported type {type(input_spec)}")
 
 
 def to_quantized_edge_program(
     model: torch.nn.Module,
-    input_shapes: tuple[int] | list[tuple[int]],
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
     operators_not_to_delegate: list[str] = None,
+    get_calibration_inputs_fn: Callable[
+        [tuple[ModelInputSpec, ...]], list[tuple[torch.Tensor, ...]]
+    ] = get_random_calibration_inputs,
     target="imxrt700",
     neutron_converter_flavor="SDK_25_03",
     remove_quant_io_ops=False,
 ) -> EdgeProgramManager:
-    if isinstance(input_shapes, list):
-        assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
-            "For multiple inputs, provide" " list[tuple[int]]."
-        )
+    calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec))
 
-    calibration_inputs = [get_random_float_data(input_shapes) for _ in range(4)]
-    example_input = (
-        (torch.ones(input_shapes),)
-        if type(input_shapes) is tuple
-        else tuple(torch.ones(input_shape) for input_shape in input_shapes)
-    )
+    example_input = calibration_inputs[0]
 
     exir_program_aten = torch.export.export_for_training(
         model, example_input, strict=True
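
to_model_input_spec() normalizes the three accepted spellings of input_spec into a tuple of ModelInputSpec objects. A quick sketch of the equivalences (the shapes are illustrative):

# Single input given as a plain shape tuple -> wrapped in one ModelInputSpec.
to_model_input_spec((1, 10))              # (ModelInputSpec(shape=(1, 10)),)

# Multiple inputs given as a list of shape tuples -> one spec per shape.
to_model_input_spec([(1, 10), (1, 20)])   # (ModelInputSpec((1, 10)), ModelInputSpec((1, 20)))

# Already-normalized specs pass through unchanged.
to_model_input_spec((ModelInputSpec((1, 10)),))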
@@ -94,27 +121,22 @@ def to_quantized_edge_program(
 
 
 def to_quantized_executorch_program(
-    model: torch.nn.Module, input_shapes: tuple[int] | list[tuple[int]]
+    model: torch.nn.Module,
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
 ) -> ExecutorchProgramManager:
-    edge_program_manager = to_quantized_edge_program(model, input_shapes)
+    edge_program_manager = to_quantized_edge_program(model, input_spec)
 
     return edge_program_manager.to_executorch(
         config=ExecutorchBackendConfig(extract_delegate_segments=False)
     )
 
 
 def to_edge_program(
-    model: nn.Module, input_shapes: tuple[int] | list[tuple[int]]
+    model: nn.Module,
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
 ) -> EdgeProgramManager:
-    if isinstance(input_shapes, list):
-        assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
-            "For multiple inputs, provide" " list[tuple[int]]."
-        )
+    calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_spec))
 
-    example_input = (
-        (torch.ones(input_shapes),)
-        if type(input_shapes) is tuple
-        else tuple(torch.ones(input_shape) for input_shape in input_shapes)
-    )
+    example_input = calibration_inputs[0]
     exir_program = torch.export.export(model, example_input)
     return exir.to_edge(exir_program)
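
The commit's stated goal, supporting different calibration data sources, hinges on the new get_calibration_inputs_fn parameter. A hedged sketch of plugging in a dataset-backed source instead of the random default; the sample list and model here are hypothetical stand-ins:

import torch

# Hypothetical stand-in for a recorded calibration dataset.
calibration_samples = [torch.randn(1, 3, 32, 32) for _ in range(8)]

def get_dataset_calibration_inputs(
    input_spec: tuple[ModelInputSpec, ...],
) -> list[tuple[torch.Tensor, ...]]:
    # Return real samples instead of random noise; input_spec is unused here
    # but keeps the signature compatible with the Callable annotation.
    return [(sample,) for sample in calibration_samples[:4]]

edge_program_manager = to_quantized_edge_program(
    torch.nn.Conv2d(3, 8, 3),  # any exportable module; illustrative only
    input_spec=(1, 3, 32, 32),
    get_calibration_inputs_fn=get_dataset_calibration_inputs,
)

Calibrating on samples from the deployment distribution should generally yield better quantization parameters than the torch.randn default.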
