3939from torch ._inductor .decomposition import remove_decompositions
4040from torch .ao .quantization .quantize_pt2e import convert_pt2e , prepare_pt2e
4141
42- from torch .export import export
4342from torch .export .exported_program import ExportedProgram
4443
4544from .passes import get_cadence_passes
5554# however useful for unit tests to separate the converted model from the fused
5655# model, to be able to get reference numerics.
5756# If this does not apply, please use quantize_and_fuse_pt2 instead.
58- def prepare_and_convert_pt2 (
57+ def trace (
5958 model : torch .nn .Module ,
6059 inputs : tuple [object , ...],
61- quantizer : CadenceQuantizer ,
62- calibration_data : Optional [list [tuple [object , ...]]] = None ,
6360 dump_graphs : bool = False ,
64- ) -> torch . fx . GraphModule :
61+ ) -> ExportedProgram :
6562 """
66- Prepare and convert a model using the given quantizer.
67- The quantizer must be supplied and be the same as the one used to
68- fuse the model later, if applicable. If you do not expect that behavior,
69- please use quantize_and_fuse_pt2 instead, which will instantiate a
70- default quantizer for you if needed.
71- If calibration data is provided, it will be used to calibrate the model. If
72- not, the inputs will be used for calibration instead, which is useful for
73- unit tests but should not be used for end-to-end use cases.
74- Returns a GraphModule with the converted model.
63+ Trace the model with export_for_training and return an ExportedProgram.
7564 """
7665
66+ # Make the model inference mode by calling model.eval()
67+ model .eval ()
68+
69+ # Prevent mkldnn decompositions
70+ torch ._C ._set_mkldnn_enabled (False )
71+
7772 # Get default decompositions
7873 decomp_table = torch .export .default_decompositions ()
74+
7975 # Select ops to keep
8076 ops_to_keep = [
8177 torch .ops .aten .conv1d .default ,
@@ -85,19 +81,46 @@ def prepare_and_convert_pt2(
8581 torch .ops .aten .matmul .default ,
8682 torch .ops .aten .rms_norm .default ,
8783 ]
84+
8885 # Remove decompositions for the ops we want to keep
8986 # pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
9087 remove_decompositions (decomp_table , ops_to_keep )
88+
9189 # Export with dynamo
92- model_gm = (
93- torch .export .export_for_training (model , inputs , strict = True )
94- .run_decompositions (decomp_table )
95- .module ()
96- )
90+ program = torch .export .export_for_training (
91+ model , inputs , strict = True
92+ ).run_decompositions (decomp_table )
9793
9894 if dump_graphs :
9995 logging .info ("Graph before quantization:" )
100- logging .info (model_gm .graph .print_tabular ())
96+ logging .info (program .module ().graph .print_tabular ())
97+
98+ return program
99+
100+
101+ def prepare_and_convert_pt2 (
102+ program : ExportedProgram ,
103+ inputs : tuple [object , ...],
104+ quantizer : CadenceQuantizer ,
105+ calibration_data : Optional [list [tuple [object , ...]]] = None ,
106+ dump_graphs : bool = False ,
107+ ) -> torch .fx .GraphModule :
108+ """
109+ Prepare and convert a model using the given quantizer.
110+ The quantizer must be supplied and be the same as the one used to
111+ fuse the model later, if applicable. If you do not expect that behavior,
112+ please use quantize_and_fuse_pt2 instead, which will instantiate a
113+ default quantizer for you if needed.
114+ If calibration data is provided, it will be used to calibrate the model. If
115+ not, the inputs will be used for calibration instead, which is useful for
116+ unit tests but should not be used for end-to-end use cases.
117+ Returns a GraphModule with the converted model.
118+ """
119+
120+ # Get the graph module from the ExportedProgram
121+ model_gm = program .module ()
122+
123+ assert isinstance (model_gm , torch .fx .GraphModule )
101124
102125 # Prepare
103126 prepared_model = prepare_pt2e (model_gm , quantizer )
@@ -121,10 +144,10 @@ def prepare_and_convert_pt2(
121144
122145
123146# Note: this is not meant as a primary API since it can create inconsistencies
124- # if the quantizer here is different from the quantizer used to convert. It is
125- # however useful for unit tests to separate the converted model from the fused
126- # model, to be able to get reference numerics.
127- # If this does not apply, please use quantize_and_fuse_pt2 instead.
147+ # if the quantizer here is different from the quantizer used to prepare/convert.
148+ # It is however useful for unit tests to separate the converted model from the
149+ # fused model, to be able to get reference numerics.
150+ # If this does not apply, please use quantize_pt2 instead.
128151def fuse_pt2 (
129152 converted_graph_module : torch .fx .GraphModule ,
130153 quantizer : CadenceQuantizer ,
@@ -167,9 +190,15 @@ def quantize_pt2(
167190 if not quantizer :
168191 quantizer = CadenceDefaultQuantizer ()
169192
193+ program = trace (model , inputs , dump_graphs = dump_graphs )
194+
195+ if dump_graphs :
196+ logging .info ("Graph after trace:" )
197+ logging .info (program .graph .print_tabular ())
198+
170199 # Get converted graph module
171200 converted_gm = prepare_and_convert_pt2 (
172- model , inputs , quantizer , calibration_data , dump_graphs = dump_graphs
201+ program , inputs , quantizer , calibration_data , dump_graphs = dump_graphs
173202 )
174203
175204 # Get fused model
@@ -184,22 +213,6 @@ def quantize_pt2(
184213 return program
185214
186215
187- # Export the model and lower it to an ExportedProgram (in aten IR)
188- def export_program (
189- model : torch .nn .Module ,
190- inputs : tuple [object , ...],
191- ) -> ExportedProgram :
192- assert isinstance (model , torch .nn .Module ), "model should be an nn.Module"
193-
194- # Prevent mkldnn decompositions
195- torch ._C ._set_mkldnn_enabled (False )
196-
197- # Export the model and return it.
198- expo_program = export (model , inputs , strict = True )
199-
200- return expo_program
201-
202-
203216def _lower_ep_to_edge (
204217 expo_program : ExportedProgram ,
205218 dump_graphs : bool = False ,
@@ -248,7 +261,7 @@ def export_to_edge(
248261 assert isinstance (model , torch .nn .Module ), "model should be an nn.Module"
249262
250263 # Export the model into an ExportedProgram.
251- expo_program = export_program (model , inputs )
264+ expo_program = trace (model , inputs )
252265
253266 # Lower the model to edge IR.
254267 edge_prog_manager = _lower_ep_to_edge (expo_program , dump_graphs , constant_methods )
0 commit comments