@@ -151,7 +151,7 @@ def quantize_pt2(
151151 quantizer : Optional [CadenceQuantizer ] = None ,
152152 calibration_data : Optional [list [tuple [object , ...]]] = None ,
153153 dump_graphs : bool = False ,
154- ) -> torch . fx . GraphModule :
154+ ) -> ExportedProgram :
155155 """
156156 Trace, prepare, convert and fuse the model using the given quantizer.
157157 If calibration data is provided, it will be used to calibrate the model. If
@@ -178,7 +178,9 @@ def quantize_pt2(
178178 logging .info ("Graph after quantization and fusion:" )
179179 logging .info (fused_gm .graph .print_tabular ())
180180
181- return fused_gm
181+ program = torch .export .export (fused_gm , inputs , strict = True )
182+
183+ return program
182184
183185
184186# Export the model and lower it to an ExportedProgram (in aten IR)
@@ -260,21 +262,43 @@ def quantize_and_export_to_edge(
260262 dump_graphs : bool = False ,
261263 constant_methods : Optional [dict [str , object ]] = None ,
262264) -> EdgeProgramManager :
265+ """
266+ Trace, quantize and lower a model/inputs pair to edge IR.
267+ """
263268 quantized_model = quantize_pt2 (
264269 model ,
265270 inputs ,
266271 quantizer = quantizer ,
267272 dump_graphs = dump_graphs ,
268273 )
269274
270- return export_to_edge (
275+ return lower_ep_to_edge (
271276 quantized_model ,
272- inputs ,
273277 dump_graphs = dump_graphs ,
274278 constant_methods = constant_methods ,
275279 )
276280
277281
282+ def lower_ep_to_cadence (
283+ program : ExportedProgram ,
284+ dump_graphs : bool = False ,
285+ opt_level : int = 1 ,
286+ ) -> EdgeProgramManager :
287+ """
288+ Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
289+ """
290+ edge_prog_manager = lower_ep_to_edge (program , dump_graphs = dump_graphs )
291+ cadence_passes = get_cadence_passes (opt_level )
292+
293+ # Run a couple required passes for quant/dequant ops
294+ cadence_prog_manager = edge_prog_manager .transform (
295+ cast (
296+ list [Callable [[torch .fx .GraphModule ], Optional [PassResult ]]], cadence_passes
297+ )
298+ )
299+ return cadence_prog_manager
300+
301+
278302def export_to_cadence (
279303 model : torch .nn .Module ,
280304 inputs : tuple [object , ...],
@@ -299,11 +323,14 @@ def quantize_and_export_to_cadence(
299323 dump_graphs : bool = False ,
300324 opt_level : int = 1 ,
301325) -> EdgeProgramManager :
326+ """
327+ Trace, quantize, lower a model/inputs pair to edge IR and apply frontend
328+ optimization passes.
329+ """
302330 quantized_model = quantize_pt2 (model , inputs )
303331
304- return export_to_cadence (
332+ return lower_ep_to_cadence (
305333 quantized_model ,
306- inputs ,
307334 opt_level = opt_level ,
308335 dump_graphs = dump_graphs ,
309336 )
0 commit comments