Skip to content

Commit e06ffc2

Browse files
NXP backend: Move calibration inputs creation outside to_quantized_edge_program()
Enables using different sources for data source
1 parent 3067e98 commit e06ffc2

File tree

1 file changed

+52
-31
lines changed

1 file changed

+52
-31
lines changed

backends/nxp/tests/executorch_pipeline.py

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from dataclasses import dataclass
7+
from typing import Callable
8+
69
import torch
710

811
from executorch import exir
@@ -23,7 +26,13 @@
2326
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
2427

2528

26-
def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
29+
@dataclass
30+
class ModelInputSpec:
31+
shape: tuple[int, ...]
32+
dtype: torch.dtype = torch.float32
33+
34+
35+
def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor, ...]]):
2736
quantizer = NeutronQuantizer()
2837

2938
m = prepare_pt2e(model, quantizer)
@@ -34,34 +43,51 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
3443
return m
3544

3645

37-
def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):
38-
# TODO: Replace with something more robust.
39-
return (
40-
(torch.randn(input_shapes),)
41-
if type(input_shapes) is tuple
42-
else tuple(torch.randn(input_shape) for input_shape in input_shapes)
43-
)
46+
def get_random_calibration_inputs(
47+
input_spec: tuple[ModelInputSpec, ...]
48+
) -> list[tuple[torch.Tensor, ...]]:
49+
return [
50+
tuple([torch.randn(spec.shape, dtype=spec.dtype) for spec in input_spec])
51+
for _ in range(4)
52+
]
53+
54+
55+
def to_model_input_spec(
56+
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
57+
) -> tuple[ModelInputSpec, ...]:
58+
59+
if isinstance(input_spec, tuple) and all(
60+
isinstance(spec, ModelInputSpec) for spec in input_spec
61+
):
62+
return input_spec
63+
64+
elif isinstance(input_spec, tuple) and all(
65+
isinstance(spec, int) for spec in input_spec
66+
):
67+
return (ModelInputSpec(input_spec),)
68+
69+
elif isinstance(input_spec, list) and all(
70+
isinstance(input_shape, tuple) for input_shape in input_spec
71+
):
72+
return tuple([ModelInputSpec(spec) for spec in input_spec])
73+
else:
74+
raise TypeError(f"Unsupported type {type(input_spec)}")
4475

4576

4677
def to_quantized_edge_program(
4778
model: torch.nn.Module,
48-
input_shapes: tuple[int] | list[tuple[int]],
79+
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
4980
operators_not_to_delegate: list[str] = None,
81+
get_calibration_inputs_fn: Callable[
82+
[tuple[ModelInputSpec, ...]], list[tuple[torch.Tensor, ...]]
83+
] = get_random_calibration_inputs,
5084
target="imxrt700",
5185
neutron_converter_flavor="SDK_25_03",
5286
remove_quant_io_ops=False,
5387
) -> EdgeProgramManager:
54-
if isinstance(input_shapes, list):
55-
assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
56-
"For multiple inputs, provide" " list[tuple[int]]."
57-
)
88+
calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec))
5889

59-
calibration_inputs = [get_random_float_data(input_shapes) for _ in range(4)]
60-
example_input = (
61-
(torch.ones(input_shapes),)
62-
if type(input_shapes) is tuple
63-
else tuple(torch.ones(input_shape) for input_shape in input_shapes)
64-
)
90+
example_input = calibration_inputs[0]
6591

6692
exir_program_aten = torch.export.export_for_training(
6793
model, example_input, strict=True
@@ -94,27 +120,22 @@ def to_quantized_edge_program(
94120

95121

96122
def to_quantized_executorch_program(
97-
model: torch.nn.Module, input_shapes: tuple[int] | list[tuple[int]]
123+
model: torch.nn.Module,
124+
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
98125
) -> ExecutorchProgramManager:
99-
edge_program_manager = to_quantized_edge_program(model, input_shapes)
126+
edge_program_manager = to_quantized_edge_program(model, input_spec)
100127

101128
return edge_program_manager.to_executorch(
102129
config=ExecutorchBackendConfig(extract_delegate_segments=False)
103130
)
104131

105132

106133
def to_edge_program(
107-
model: nn.Module, input_shapes: tuple[int] | list[tuple[int]]
134+
model: nn.Module,
135+
input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
108136
) -> EdgeProgramManager:
109-
if isinstance(input_shapes, list):
110-
assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
111-
"For multiple inputs, provide" " list[tuple[int]]."
112-
)
137+
calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_spec))
113138

114-
example_input = (
115-
(torch.ones(input_shapes),)
116-
if type(input_shapes) is tuple
117-
else tuple(torch.ones(input_shape) for input_shape in input_shapes)
118-
)
139+
example_input = calibration_inputs[0]
119140
exir_program = torch.export.export(model, example_input)
120141
return exir.to_edge(exir_program)

0 commit comments

Comments
 (0)