Skip to content

Commit deaf37f

Browse files
NXP backend: Refactor executorch_pipeline.py, allow different sources for calibration inputs (#13494)
### Summary Introduces `ModelInputSpec` for specification of model inputs. Arbitrary data type of model input can now be specified. Moves creation of calibration data outside main `to_quantized_edge_program()` function, thus enables using different data source than random data creation. ### Test plan All tests that use `executorch_pipeline.py` calls - almost all of backend tests.
1 parent 2a58471 commit deaf37f

File tree

1 file changed

+52
-31
lines changed

1 file changed

+52
-31
lines changed

backends/nxp/tests/executorch_pipeline.py

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from dataclasses import dataclass
7+
from typing import Callable
8+
69
import torch
710

811
from executorch import exir
@@ -29,7 +32,13 @@
2932
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
3033

3134

32-
def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
35+
@dataclass
36+
class ModelInputSpec:
37+
shape: tuple[int, ...]
38+
dtype: torch.dtype = torch.float32
39+
40+
41+
def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor, ...]]):
3342
quantizer = NeutronQuantizer()
3443

3544
m = prepare_pt2e(model, quantizer)
@@ -40,35 +49,52 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
4049
return m
4150

4251

43-
def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):
44-
# TODO: Replace with something more robust.
45-
return (
46-
(torch.randn(input_shapes),)
47-
if type(input_shapes) is tuple
48-
else tuple(torch.randn(input_shape) for input_shape in input_shapes)
49-
)
52+
def get_random_calibration_inputs(
53+
input_spec: tuple[ModelInputSpec, ...]
54+
) -> list[tuple[torch.Tensor, ...]]:
55+
return [
56+
tuple([torch.randn(spec.shape, dtype=spec.dtype) for spec in input_spec])
57+
for _ in range(4)
58+
]
59+
60+
61+
def to_model_input_spec(
62+
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
63+
) -> tuple[ModelInputSpec, ...]:
64+
65+
if isinstance(input_spec, tuple) and all(
66+
isinstance(spec, ModelInputSpec) for spec in input_spec
67+
):
68+
return input_spec
69+
70+
elif isinstance(input_spec, tuple) and all(
71+
isinstance(spec, int) for spec in input_spec
72+
):
73+
return (ModelInputSpec(input_spec),)
74+
75+
elif isinstance(input_spec, list) and all(
76+
isinstance(input_shape, tuple) for input_shape in input_spec
77+
):
78+
return tuple([ModelInputSpec(spec) for spec in input_spec])
79+
else:
80+
raise TypeError(f"Unsupported type {type(input_spec)}")
5081

5182

5283
def to_quantized_edge_program(
5384
model: torch.nn.Module,
54-
input_shapes: tuple[int, ...] | list[tuple[int, ...]],
85+
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
5586
operators_not_to_delegate: list[str] = None,
87+
get_calibration_inputs_fn: Callable[
88+
[tuple[ModelInputSpec, ...]], list[tuple[torch.Tensor, ...]]
89+
] = get_random_calibration_inputs,
5690
target="imxrt700",
5791
neutron_converter_flavor="SDK_25_03",
5892
remove_quant_io_ops=False,
5993
custom_delegation_options=CustomDelegationOptions(), # noqa B008
6094
) -> EdgeProgramManager:
61-
if isinstance(input_shapes, list):
62-
assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
63-
"For multiple inputs, provide" " list[tuple[int]]."
64-
)
95+
calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec))
6596

66-
calibration_inputs = [get_random_float_data(input_shapes) for _ in range(4)]
67-
example_input = (
68-
(torch.ones(input_shapes),)
69-
if type(input_shapes) is tuple
70-
else tuple(torch.ones(input_shape) for input_shape in input_shapes)
71-
)
97+
example_input = calibration_inputs[0]
7298

7399
exir_program_aten = torch.export.export_for_training(
74100
model, example_input, strict=True
@@ -104,27 +130,22 @@ def to_quantized_edge_program(
104130

105131

106132
def to_quantized_executorch_program(
107-
model: torch.nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
133+
model: torch.nn.Module,
134+
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
108135
) -> ExecutorchProgramManager:
109-
edge_program_manager = to_quantized_edge_program(model, input_shapes)
136+
edge_program_manager = to_quantized_edge_program(model, input_spec)
110137

111138
return edge_program_manager.to_executorch(
112139
config=ExecutorchBackendConfig(extract_delegate_segments=False)
113140
)
114141

115142

116143
def to_edge_program(
117-
model: nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
144+
model: nn.Module,
145+
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
118146
) -> EdgeProgramManager:
119-
if isinstance(input_shapes, list):
120-
assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
121-
"For multiple inputs, provide" " list[tuple[int]]."
122-
)
147+
calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_spec))
123148

124-
example_input = (
125-
(torch.ones(input_shapes),)
126-
if type(input_shapes) is tuple
127-
else tuple(torch.ones(input_shape) for input_shape in input_shapes)
128-
)
149+
example_input = calibration_inputs[0]
129150
exir_program = torch.export.export(model, example_input)
130151
return exir.to_edge(exir_program)

0 commit comments

Comments
 (0)