@@ -172,29 +172,18 @@ def fuse_pt2(
     return converted_graph_module
 
 
-def quantize_pt2(
+# Note: quantizer is not optional here to force the user to supply a quantizer
+# and to make it more likely that consistency is maintained.
+def get_fake_quant_model(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: Optional[CadenceQuantizer] = None,
+    quantizer: CadenceQuantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> ExportedProgram:
-    """
-    Trace, prepare, convert and fuse the model using the given quantizer.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the quantized model.
-    Note: this function should not be called directly in general. Please use
-    quantize_and_export_to_executorch for most needs.
-    """
+) -> torch.fx.GraphModule:
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    # Instantiate the quantizer to CadenceQuantizer if not supplied
-    if not quantizer:
-        quantizer = CadenceDefaultQuantizer()
-
     program = trace(model, inputs, dump_graphs=dump_graphs)
 
     if dump_graphs:
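
With this change, `get_fake_quant_model` requires an explicit quantizer rather than silently defaulting. A minimal usage sketch, assuming `CadenceDefaultQuantizer` is importable from the Cadence AoT quantizer module and that `model`/`example_inputs` are placeholders for a real model and its example inputs:

```python
# Hypothetical usage sketch; the import path is assumed from the ExecuTorch tree.
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

quantizer = CadenceDefaultQuantizer()
fake_quant_gm = get_fake_quant_model(
    model,                # a torch.nn.Module
    example_inputs,       # tuple of example inputs used for tracing
    quantizer=quantizer,  # now required; no implicit default
)
# fake_quant_gm is a torch.fx.GraphModule with fake-quant (Q/DQ) ops inserted.
```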
@@ -214,6 +203,37 @@ def quantize_pt2(
 
     # Get converted graph module
     converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)
+    return converted_gm
+
+
+def quantize_pt2(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: Optional[CadenceQuantizer] = None,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> ExportedProgram:
+    """
+    Trace, prepare, convert and fuse the model using the given quantizer.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the quantized model.
+    Note: this function should not be called directly in general. Please use
+    quantize_and_export_to_executorch for most needs.
+    """
+    # Instantiate the quantizer to CadenceQuantizer if not supplied
+    if not quantizer:
+        quantizer = CadenceDefaultQuantizer()
+
+    # Get the converted (aka fake quant) graph module
+    converted_gm = get_fake_quant_model(
+        model,
+        inputs,
+        quantizer=quantizer,
+        calibration_data=calibration_data,
+        dump_graphs=dump_graphs,
+    )
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
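
The relocated `quantize_pt2` keeps its old quantizer-defaulting behavior and now delegates the trace/prepare/convert steps to `get_fake_quant_model` before fusing. A sketch of the end-to-end call its docstring describes, where `representative_batches` is a hypothetical iterable of calibration inputs:

```python
# Hypothetical usage sketch; quantizer defaults to CadenceDefaultQuantizer.
calibration_data = [(batch,) for batch in representative_batches]
quantized_program = quantize_pt2(
    model,
    example_inputs,
    calibration_data=calibration_data,  # omit only in unit tests
)
```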
@@ -237,7 +257,7 @@ def quantize_pt2(
     torch.ops.aten.angle.default,
     torch.ops.aten.rms_norm.default,
 ]
-TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload, ...] = [
+TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload] = [
     torch.ops.aten.rms_norm.default,
 ]
 
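
The last hunk fixes an invalid annotation: `list` takes a single element type, and the `Type, ...` form is only valid for variable-length tuples. An illustration of the distinction, not part of the patch:

```python
ops: list[torch._ops.OpOverload] = [torch.ops.aten.rms_norm.default]  # one element type
shape: tuple[int, ...] = (1, 3, 224, 224)  # `...` marks a variable-length tuple
```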