33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ from dataclasses import dataclass
7+ from typing import Callable
8+
69import torch
710
811from executorch import exir
2326from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
2427
2528
26- def _quantize_model (model , calibration_inputs : list [tuple [torch .Tensor ]]):
29+ @dataclass
30+ class ModelInputSpec :
31+ shape : tuple [int , ...]
32+ dtype : torch .dtype = torch .float32
33+
34+
35+ def _quantize_model (model , calibration_inputs : list [tuple [torch .Tensor , ...]]):
2736 quantizer = NeutronQuantizer ()
2837
2938 m = prepare_pt2e (model , quantizer )
@@ -34,34 +43,51 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
3443 return m
3544
3645
37- def get_random_float_data (input_shapes : tuple [int ] | list [tuple [int ]]):
38- # TODO: Replace with something more robust.
39- return (
40- (torch .randn (input_shapes ),)
41- if type (input_shapes ) is tuple
42- else tuple (torch .randn (input_shape ) for input_shape in input_shapes )
43- )
46+ def get_random_calibration_inputs (
47+ input_spec : tuple [ModelInputSpec , ...]
48+ ) -> list [tuple [torch .Tensor , ...]]:
49+ return [
50+ tuple ([torch .randn (spec .shape , dtype = spec .dtype ) for spec in input_spec ])
51+ for _ in range (4 )
52+ ]
53+
54+
55+ def to_model_input_spec (
56+ input_spec : tuple [ModelInputSpec , ...] | tuple [int , ...] | list [tuple [int , ...]]
57+ ) -> tuple [ModelInputSpec , ...]:
58+
59+ if isinstance (input_spec , tuple ) and all (
60+ isinstance (spec , ModelInputSpec ) for spec in input_spec
61+ ):
62+ return input_spec
63+
64+ elif isinstance (input_spec , tuple ) and all (
65+ isinstance (spec , int ) for spec in input_spec
66+ ):
67+ return (ModelInputSpec (input_spec ),)
68+
69+ elif isinstance (input_spec , list ) and all (
70+ isinstance (input_shape , tuple ) for input_shape in input_spec
71+ ):
72+ return tuple ([ModelInputSpec (spec ) for spec in input_spec ])
73+ else :
74+ raise TypeError (f"Unsupported type { type (input_spec )} " )
4475
4576
4677def to_quantized_edge_program (
4778 model : torch .nn .Module ,
48- input_shapes : tuple [int ] | list [tuple [int ]],
79+ input_spec : tuple [ModelInputSpec , ...] | tuple [ int , ... ] | list [tuple [int , ... ]],
4980 operators_not_to_delegate : list [str ] = None ,
81+ get_calibration_inputs_fn : Callable [
82+ [tuple [ModelInputSpec , ...]], list [tuple [torch .Tensor , ...]]
83+ ] = get_random_calibration_inputs ,
5084 target = "imxrt700" ,
5185 neutron_converter_flavor = "SDK_25_03" ,
5286 remove_quant_io_ops = False ,
5387) -> EdgeProgramManager :
54- if isinstance (input_shapes , list ):
55- assert all (isinstance (input_shape , tuple ) for input_shape in input_shapes ), (
56- "For multiple inputs, provide" " list[tuple[int]]."
57- )
88+ calibration_inputs = get_calibration_inputs_fn (to_model_input_spec (input_spec ))
5889
59- calibration_inputs = [get_random_float_data (input_shapes ) for _ in range (4 )]
60- example_input = (
61- (torch .ones (input_shapes ),)
62- if type (input_shapes ) is tuple
63- else tuple (torch .ones (input_shape ) for input_shape in input_shapes )
64- )
90+ example_input = calibration_inputs [0 ]
6591
6692 exir_program_aten = torch .export .export_for_training (
6793 model , example_input , strict = True
@@ -94,27 +120,22 @@ def to_quantized_edge_program(
94120
95121
96122def to_quantized_executorch_program (
97- model : torch .nn .Module , input_shapes : tuple [int ] | list [tuple [int ]]
123+ model : torch .nn .Module ,
124+ input_spec : tuple [ModelInputSpec , ...] | tuple [int , ...] | list [tuple [int , ...]],
98125) -> ExecutorchProgramManager :
99- edge_program_manager = to_quantized_edge_program (model , input_shapes )
126+ edge_program_manager = to_quantized_edge_program (model , input_spec )
100127
101128 return edge_program_manager .to_executorch (
102129 config = ExecutorchBackendConfig (extract_delegate_segments = False )
103130 )
104131
105132
106133def to_edge_program (
107- model : nn .Module , input_shapes : tuple [int ] | list [tuple [int ]]
134+ model : nn .Module ,
135+ input_spec : tuple [ModelInputSpec , ...] | tuple [int , ...] | list [tuple [int , ...]],
108136) -> EdgeProgramManager :
109- if isinstance (input_shapes , list ):
110- assert all (isinstance (input_shape , tuple ) for input_shape in input_shapes ), (
111- "For multiple inputs, provide" " list[tuple[int]]."
112- )
137+ calibration_inputs = get_random_calibration_inputs (to_model_input_spec (input_spec ))
113138
114- example_input = (
115- (torch .ones (input_shapes ),)
116- if type (input_shapes ) is tuple
117- else tuple (torch .ones (input_shape ) for input_shape in input_shapes )
118- )
139+ example_input = calibration_inputs [0 ]
119140 exir_program = torch .export .export (model , example_input )
120141 return exir .to_edge (exir_program )
0 commit comments