3838)
3939from executorch .exir .passes import ToOutVarPass
4040from executorch .exir .passes .sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
41- from executorch .exir .program ._program import to_edge
41+ from executorch .exir .program ._program import _transform , to_edge
4242
4343from torch .export .exported_program import ExportedProgram
4444from torchao .quantization .pt2e .quantize_pt2e import convert_pt2e
@@ -145,22 +145,22 @@ def convert_pt2(
145145# fused model, to be able to get reference numerics.
146146# If this does not apply, please use quantize_pt2 instead.
147147def fuse_pt2 (
148- converted_graph_module : torch . fx . GraphModule ,
148+ converted_program : ExportedProgram ,
149149 quantizer : CadenceQuantizer ,
150- ) -> torch . fx . GraphModule :
150+ ) -> ExportedProgram :
151151 """
152- Fuse a converted graph module using the given quantizer.
152+ Fuse a converted exported program using the given quantizer.
153153 The quantizer must be the same as the one used to convert the model.
154154 If you do not expect that behavior, please use quantize_pt2 instead,
155155 which will instantiate a default quantizer for you if needed.
156- Returns a GraphModule with the fused model.
156+ Returns an ExportedProgram with the fused model.
157157 """
158158 # Get patterns and apply fusion of dq -> op -> q to qop
159159 # pyre-ignore[16]: no attribute
160160 patterns = [q .pattern for q in quantizer .quantizers ]
161- QuantFusion (patterns )( converted_graph_module )
161+ fused_program = _transform ( converted_program , QuantFusion (patterns ))
162162
163- return converted_graph_module
163+ return fused_program
164164
165165
166166# Note: quantizer is not optional here to force the user to supply a quantizer
@@ -210,7 +210,7 @@ def quantize_pt2(
210210 If calibration data is provided, it will be used to calibrate the model. If
211211 not, the inputs will be used for calibration instead, which is useful for
212212 unit tests but should not be used for end-to-end use cases.
213- Returns a GraphModule with the quantized model.
213+ Returns an ExportedProgram with the quantized model.
214214 Note: this function should not be called directly in general. Please use
215215 quantize_and_export_to_executorch for most needs.
216216 """
@@ -227,16 +227,15 @@ def quantize_pt2(
227227 dump_graphs = dump_graphs ,
228228 )
229229
230- # Get fused model
231- fused_gm = fuse_pt2 (converted_gm , quantizer )
230+ # Apply quant fusion to the exported program
231+ program = torch .export .export (converted_gm , inputs , strict = True )
232+ fused_program = fuse_pt2 (program , quantizer )
232233
233234 if dump_graphs :
234235 logging .info ("Graph after quantization and fusion:" )
235- logging .info (fused_gm .graph .print_tabular ())
236+ logging .info (fused_program . graph_module .graph .print_tabular ())
236237
237- program = torch .export .export (fused_gm , inputs , strict = True )
238-
239- return program
238+ return fused_program
240239
241240
242241TO_EDGE_OP_EXCEPTION_LIST : list [torch ._ops .OpOverload ] = [
0 commit comments