@@ -3,6 +3,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from dataclasses import dataclass
+from typing import Callable
+
 import torch
 
 from executorch import exir
@@ -29,7 +32,13 @@
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 
-def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
+@dataclass
+class ModelInputSpec:
+    shape: tuple[int, ...]
+    dtype: torch.dtype = torch.float32
+
+
+def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor, ...]]):
     quantizer = NeutronQuantizer()
 
     m = prepare_pt2e(model, quantizer)
@@ -40,35 +49,52 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
     return m
 
 
-def get_random_float_data(input_shapes: tuple[int] | list[tuple[int]]):
-    # TODO: Replace with something more robust.
-    return (
-        (torch.randn(input_shapes),)
-        if type(input_shapes) is tuple
-        else tuple(torch.randn(input_shape) for input_shape in input_shapes)
-    )
+def get_random_calibration_inputs(
+    input_spec: tuple[ModelInputSpec, ...]
+) -> list[tuple[torch.Tensor, ...]]:
+    return [
+        tuple([torch.randn(spec.shape, dtype=spec.dtype) for spec in input_spec])
+        for _ in range(4)
+    ]
+
+
+def to_model_input_spec(
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
+) -> tuple[ModelInputSpec, ...]:
+
+    if isinstance(input_spec, tuple) and all(
+        isinstance(spec, ModelInputSpec) for spec in input_spec
+    ):
+        return input_spec
+
+    elif isinstance(input_spec, tuple) and all(
+        isinstance(spec, int) for spec in input_spec
+    ):
+        return (ModelInputSpec(input_spec),)
+
+    elif isinstance(input_spec, list) and all(
+        isinstance(input_shape, tuple) for input_shape in input_spec
+    ):
+        return tuple([ModelInputSpec(spec) for spec in input_spec])
+    else:
+        raise TypeError(f"Unsupported type {type(input_spec)}")
 
 
 def to_quantized_edge_program(
     model: torch.nn.Module,
-    input_shapes: tuple[int, ...] | list[tuple[int, ...]],
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
     operators_not_to_delegate: list[str] = None,
+    get_calibration_inputs_fn: Callable[
+        [tuple[ModelInputSpec, ...]], list[tuple[torch.Tensor, ...]]
+    ] = get_random_calibration_inputs,
     target="imxrt700",
     neutron_converter_flavor="SDK_25_03",
     remove_quant_io_ops=False,
     custom_delegation_options=CustomDelegationOptions(),  # noqa B008
 ) -> EdgeProgramManager:
-    if isinstance(input_shapes, list):
-        assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
-            "For multiple inputs, provide" " list[tuple[int]]."
-        )
+    calibration_inputs = get_calibration_inputs_fn(to_model_input_spec(input_spec))
 
-    calibration_inputs = [get_random_float_data(input_shapes) for _ in range(4)]
-    example_input = (
-        (torch.ones(input_shapes),)
-        if type(input_shapes) is tuple
-        else tuple(torch.ones(input_shape) for input_shape in input_shapes)
-    )
+    example_input = calibration_inputs[0]
 
     exir_program_aten = torch.export.export_for_training(
         model, example_input, strict=True
@@ -104,27 +130,22 @@ def to_quantized_edge_program(
 
 
 def to_quantized_executorch_program(
-    model: torch.nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
+    model: torch.nn.Module,
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
 ) -> ExecutorchProgramManager:
-    edge_program_manager = to_quantized_edge_program(model, input_shapes)
+    edge_program_manager = to_quantized_edge_program(model, input_spec)
 
     return edge_program_manager.to_executorch(
         config=ExecutorchBackendConfig(extract_delegate_segments=False)
     )
 
 
 def to_edge_program(
-    model: nn.Module, input_shapes: tuple[int, ...] | list[tuple[int, ...]]
+    model: nn.Module,
+    input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]],
 ) -> EdgeProgramManager:
-    if isinstance(input_shapes, list):
-        assert all(isinstance(input_shape, tuple) for input_shape in input_shapes), (
-            "For multiple inputs, provide" " list[tuple[int]]."
-        )
+    calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_spec))
 
-    example_input = (
-        (torch.ones(input_shapes),)
-        if type(input_shapes) is tuple
-        else tuple(torch.ones(input_shape) for input_shape in input_shapes)
-    )
+    example_input = calibration_inputs[0]
     exir_program = torch.export.export(model, example_input)
     return exir.to_edge(exir_program)
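
For reference, a minimal sketch (not part of the diff) of how the reworked input_spec argument is normalized by the helpers above; the shape values and variable names are illustrative only:

# All three forms normalize to the same tuple[ModelInputSpec, ...]:
single_shape = to_model_input_spec((1, 3, 32, 32))                 # bare shape tuple
shape_list = to_model_input_spec([(1, 3, 32, 32)])                 # list of shape tuples
explicit = to_model_input_spec((ModelInputSpec((1, 3, 32, 32)),))  # explicit specs

# The default calibration source draws 4 random batches matching the spec;
# the first batch doubles as the example input for torch.export.
calibration_inputs = get_random_calibration_inputs(single_shape)
example_input = calibration_inputs[0]  # tuple with one (1, 3, 32, 32) float32 tensor

Any of these forms can also be passed directly as the input_spec argument of to_quantized_edge_program, to_quantized_executorch_program, or to_edge_program, which normalize it via to_model_input_spec internally.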