3
3
# This source code is licensed under the BSD-style license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
5
6
+ from dataclasses import dataclass
7
+ from typing import Callable
8
+
6
9
import torch
7
10
8
11
from executorch import exir
29
32
from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
30
33
31
34
32
- def _quantize_model (model , calibration_inputs : list [tuple [torch .Tensor ]]):
35
@dataclass
class ModelInputSpec:
    """Specification of a single model input used to generate example and
    calibration tensors."""

    # Tensor shape for this input, e.g. (1, 3, 32, 32).
    shape: tuple[int, ...]
    # Element dtype of the generated tensors; defaults to float32.
    dtype: torch.dtype = torch.float32
40
+
41
+ def _quantize_model (model , calibration_inputs : list [tuple [torch .Tensor , ...]]):
33
42
quantizer = NeutronQuantizer ()
34
43
35
44
m = prepare_pt2e (model , quantizer )
@@ -40,35 +49,52 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
40
49
return m
41
50
42
51
43
- def get_random_float_data (input_shapes : tuple [int ] | list [tuple [int ]]):
44
- # TODO: Replace with something more robust.
45
- return (
46
- (torch .randn (input_shapes ),)
47
- if type (input_shapes ) is tuple
48
- else tuple (torch .randn (input_shape ) for input_shape in input_shapes )
49
- )
52
def get_random_calibration_inputs(
    input_spec: "tuple[ModelInputSpec, ...]",
    num_samples: int = 4,
) -> list[tuple[torch.Tensor, ...]]:
    """Generate random calibration data matching the given input specs.

    :param input_spec: One spec per model input; each spec's `shape` and
        `dtype` determine the generated tensor.
    :param num_samples: Number of calibration samples to generate. Defaults
        to 4, the previously hard-coded count, so existing callers are
        unaffected.
    :return: `num_samples` tuples, each holding one random tensor per spec.
    """
    return [
        tuple(torch.randn(spec.shape, dtype=spec.dtype) for spec in input_spec)
        for _ in range(num_samples)
    ]
59
+
60
+
61
def to_model_input_spec(
    input_spec: "tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]",
) -> "tuple[ModelInputSpec, ...]":
    """Normalize any accepted input-spec format into a tuple of ModelInputSpec.

    Accepted formats:
      * tuple of ModelInputSpec -- returned unchanged,
      * a single shape as a tuple of ints -- wrapped into one ModelInputSpec,
      * a list (or tuple) of shape tuples -- one ModelInputSpec per shape.

    :raises TypeError: if `input_spec` matches none of the formats above.
    """
    if isinstance(input_spec, tuple) and all(
        isinstance(spec, ModelInputSpec) for spec in input_spec
    ):
        return input_spec

    if isinstance(input_spec, tuple) and all(
        isinstance(spec, int) for spec in input_spec
    ):
        return (ModelInputSpec(input_spec),)

    # Generalization: accept a tuple of shape tuples as well as a list of
    # them -- a tuple of tuples cannot be a single shape (shape elements are
    # ints), so it unambiguously describes multiple inputs.
    if isinstance(input_spec, (list, tuple)) and all(
        isinstance(shape, tuple) for shape in input_spec
    ):
        return tuple(ModelInputSpec(shape) for shape in input_spec)

    raise TypeError(f"Unsupported type {type(input_spec)}")
50
81
51
82
52
83
def to_quantized_edge_program (
53
84
model : torch .nn .Module ,
54
- input_shapes : tuple [int , ...] | list [tuple [int , ...]],
85
+ input_spec : tuple [ ModelInputSpec , ...] | tuple [int , ...] | list [tuple [int , ...]],
55
86
operators_not_to_delegate : list [str ] = None ,
87
+ get_calibration_inputs_fn : Callable [
88
+ [tuple [ModelInputSpec , ...]], list [tuple [torch .Tensor , ...]]
89
+ ] = get_random_calibration_inputs ,
56
90
target = "imxrt700" ,
57
91
neutron_converter_flavor = "SDK_25_03" ,
58
92
remove_quant_io_ops = False ,
59
93
custom_delegation_options = CustomDelegationOptions (), # noqa B008
60
94
) -> EdgeProgramManager :
61
- if isinstance (input_shapes , list ):
62
- assert all (isinstance (input_shape , tuple ) for input_shape in input_shapes ), (
63
- "For multiple inputs, provide" " list[tuple[int]]."
64
- )
95
+ calibration_inputs = get_calibration_inputs_fn (to_model_input_spec (input_spec ))
65
96
66
- calibration_inputs = [get_random_float_data (input_shapes ) for _ in range (4 )]
67
- example_input = (
68
- (torch .ones (input_shapes ),)
69
- if type (input_shapes ) is tuple
70
- else tuple (torch .ones (input_shape ) for input_shape in input_shapes )
71
- )
97
+ example_input = calibration_inputs [0 ]
72
98
73
99
exir_program_aten = torch .export .export_for_training (
74
100
model , example_input , strict = True
@@ -104,27 +130,22 @@ def to_quantized_edge_program(
104
130
105
131
106
132
def to_quantized_executorch_program(
    model: torch.nn.Module,
    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
) -> ExecutorchProgramManager:
    """Quantize `model`, lower it to an Edge program, and serialize it to an
    ExecuTorch program.

    :param model: Module to quantize and lower.
    :param input_spec: Input description in any format accepted by
        `to_model_input_spec`.
    :return: The resulting ExecutorchProgramManager.
    """
    # Delegate segments are kept inline in the program binary.
    backend_config = ExecutorchBackendConfig(extract_delegate_segments=False)
    return to_quantized_edge_program(model, input_spec).to_executorch(
        config=backend_config
    )
114
141
115
142
116
143
def to_edge_program(
    model: nn.Module,
    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
) -> EdgeProgramManager:
    """Export `model` with a random example input and lower it to an Edge
    program (no quantization).

    :param model: Module to export and lower.
    :param input_spec: Input description in any format accepted by
        `to_model_input_spec`.
    :return: The resulting EdgeProgramManager.
    """
    spec = to_model_input_spec(input_spec)
    # The first randomly generated calibration sample doubles as the
    # example input for export.
    example_input = get_random_calibration_inputs(spec)[0]
    exported_program = torch.export.export(model, example_input)
    return exir.to_edge(exported_program)
0 commit comments