@@ -149,29 +149,17 @@ def fuse_pt2(
     return converted_graph_module
 
 
-def quantize_pt2(
-    model: torch.nn.Module,
+# Note: quantizer is not optional here to force the user to supply a quantizer
+# and ensure consistency is more likely to be maintained.
+def get_fake_quant_model(model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: Optional[CadenceQuantizer] = None,
+    quantizer: CadenceQuantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> ExportedProgram:
-    """
-    Trace, prepare, convert and fuse the model using the given quantizer.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the quantized model.
-    Note: this function should not be called directly in general. Please use
-    quantize_and_export_to_executorch for most needs.
-    """
+) -> torch.fx.GraphModule:
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    # Instantiate the quantizer to CadenceQuantizer if not supplied
-    if not quantizer:
-        quantizer = CadenceDefaultQuantizer()
-
    program = trace(model, inputs, dump_graphs=dump_graphs)
 
    if dump_graphs:
@@ -191,6 +179,37 @@ def quantize_pt2(
 
     # Get converted graph module
     converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)
+    return converted_gm
+
+
+def quantize_pt2(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: Optional[CadenceQuantizer] = None,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> ExportedProgram:
+    """
+    Trace, prepare, convert and fuse the model using the given quantizer.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the quantized model.
+    Note: this function should not be called directly in general. Please use
+    quantize_and_export_to_executorch for most needs.
+    """
+    # Instantiate the quantizer to CadenceQuantizer if not supplied
+    if not quantizer:
+        quantizer = CadenceDefaultQuantizer()
+
+    # Get the converted (aka fake quant) graph module
+    converted_gm = get_fake_quant_model(
+        model,
+        inputs,
+        quantizer=quantizer,
+        calibration_data=calibration_data,
+        dump_graphs=dump_graphs
+    )
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
@@ -203,7 +222,6 @@ def quantize_pt2(
 
     return program
 
-
 TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
     torch.ops.aten._linalg_det.default,
     torch.ops.aten._linalg_svd.default,
@@ -214,7 +232,7 @@ def quantize_pt2(
     torch.ops.aten.angle.default,
     torch.ops.aten.rms_norm.default,
 ]
-TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload, ...] = [
+TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload] = [
     torch.ops.aten.rms_norm.default,
 ]
 