33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ from dataclasses import dataclass
7+ from typing import Callable
8+
9+ import numpy as np
610import torch
711
812from executorch import exir
2327from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e , prepare_pt2e
2428
2529
26- def _quantize_model (model , calibration_inputs : list [tuple [torch .Tensor ]]):
30+ @dataclass
31+ class ModelInputSpec :
32+ shape : tuple [int , ...]
33+ dtype : torch .dtype = torch .float32
34+
35+
36+ def _quantize_model (model , calibration_inputs : list [tuple [torch .Tensor , ...]]):
2737 quantizer = NeutronQuantizer ()
2838
2939 m = prepare_pt2e (model , quantizer )
@@ -34,34 +44,51 @@ def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
3444 return m
3545
3646
37- def get_random_float_data (input_shapes : tuple [int ] | list [tuple [int ]]):
38- # TODO: Replace with something more robust.
39- return (
40- (torch .randn (input_shapes ),)
41- if type (input_shapes ) is tuple
42- else tuple (torch .randn (input_shape ) for input_shape in input_shapes )
43- )
47+ def get_random_calibration_inputs (
48+ input_spec : tuple [ModelInputSpec , ...]
49+ ) -> list [tuple [torch .Tensor , ...]]:
50+ return [
51+ tuple ([torch .randn (spec .shape , dtype = spec .dtype ) for spec in input_spec ])
52+ for _ in range (4 )
53+ ]
54+
55+
56+ def to_model_input_spec (
57+ input_spec : tuple [ModelInputSpec , ...] | tuple [int , ...] | list [tuple [int , ...]]
58+ ) -> tuple [ModelInputSpec , ...]:
59+
60+ if isinstance (input_spec , tuple ) and all (
61+ isinstance (spec , ModelInputSpec ) for spec in input_spec
62+ ):
63+ return input_spec
64+
65+ elif isinstance (input_spec , tuple ) and all (
66+ isinstance (spec , int ) for spec in input_spec
67+ ):
68+ return (ModelInputSpec (input_spec ),)
69+
70+ elif isinstance (input_spec , list ) and all (
71+ isinstance (input_shape , tuple ) for input_shape in input_spec
72+ ):
73+ return tuple ([ModelInputSpec (spec ) for spec in input_spec ])
74+ else :
75+ raise TypeError (f"Unsupported type { type (input_spec )} " )
4476
4577
4678def to_quantized_edge_program (
4779 model : torch .nn .Module ,
48- input_shapes : tuple [int ] | list [tuple [int ]],
80+ input_spec : tuple [ModelInputSpec , ...] | tuple [ int , ... ] | list [tuple [int , ... ]],
4981 operators_not_to_delegate : list [str ] = None ,
82+ get_calibration_inputs_fn : Callable [
83+ [tuple [ModelInputSpec , ...]], list [tuple [torch .Tensor , ...]]
84+ ] = get_random_calibration_inputs ,
5085 target = "imxrt700" ,
5186 neutron_converter_flavor = "SDK_25_03" ,
5287 remove_quant_io_ops = False ,
5388) -> EdgeProgramManager :
54- if isinstance (input_shapes , list ):
55- assert all (isinstance (input_shape , tuple ) for input_shape in input_shapes ), (
56- "For multiple inputs, provide" " list[tuple[int]]."
57- )
89+ calibration_inputs = get_calibration_inputs_fn (to_model_input_spec (input_spec ))
5890
59- calibration_inputs = [get_random_float_data (input_shapes ) for _ in range (4 )]
60- example_input = (
61- (torch .ones (input_shapes ),)
62- if type (input_shapes ) is tuple
63- else tuple (torch .ones (input_shape ) for input_shape in input_shapes )
64- )
91+ example_input = calibration_inputs [0 ]
6592
6693 exir_program_aten = torch .export .export_for_training (
6794 model , example_input , strict = True
@@ -94,27 +121,22 @@ def to_quantized_edge_program(
94121
95122
96123def to_quantized_executorch_program (
97- model : torch .nn .Module , input_shapes : tuple [int ] | list [tuple [int ]]
124+ model : torch .nn .Module ,
125+ input_spec : tuple [ModelInputSpec , ...] | tuple [int , ...] | list [tuple [int , ...]],
98126) -> ExecutorchProgramManager :
99- edge_program_manager = to_quantized_edge_program (model , input_shapes )
127+ edge_program_manager = to_quantized_edge_program (model , input_spec )
100128
101129 return edge_program_manager .to_executorch (
102130 config = ExecutorchBackendConfig (extract_delegate_segments = False )
103131 )
104132
105133
106134def to_edge_program (
107- model : nn .Module , input_shapes : tuple [int ] | list [tuple [int ]]
135+ model : nn .Module ,
136+ input_spec : tuple [ModelInputSpec , ...] | tuple [int , ...] | list [tuple [int , ...]],
108137) -> EdgeProgramManager :
109- if isinstance (input_shapes , list ):
110- assert all (isinstance (input_shape , tuple ) for input_shape in input_shapes ), (
111- "For multiple inputs, provide" " list[tuple[int]]."
112- )
138+ calibration_inputs = get_random_calibration_inputs (to_model_input_spec (input_spec ))
113139
114- example_input = (
115- (torch .ones (input_shapes ),)
116- if type (input_shapes ) is tuple
117- else tuple (torch .ones (input_shape ) for input_shape in input_shapes )
118- )
140+ example_input = calibration_inputs [0 ]
119141 exir_program = torch .export .export (model , example_input )
120142 return exir .to_edge (exir_program )
0 commit comments