5454# if the quantizer here is different from the quantizer used to convert. It is
5555# however useful for unit tests to separate the converted model from the fused
5656# model, to be able to get reference numerics.
57- # If this does not apply, please use quantize_and_fuse_pt2 instead.
57+ # If this does not apply, please use quantize_pt2 instead.
5858def trace (
5959 model : torch .nn .Module ,
6060 inputs : tuple [object , ...],
@@ -85,6 +85,29 @@ def trace(
8585
8686
8787def prepare_pt2 (
88+ model : torch .nn .Module ,
89+ inputs : tuple [object , ...],
90+ quantizer : CadenceQuantizer ,
91+ dump_graphs : bool = False ,
92+ ) -> torch .fx .GraphModule :
93+ """
94+ Trace and Prepare a model using the given quantizer.
95+ The quantizer must be supplied and be the same as the one used to
96+ fuse the model later, if applicable. If you do not expect that behavior,
97+ please use quantize_pt2 instead, which will instantiate a
98+ default quantizer for you if needed.
99+ Returns a GraphModule with the prepared model.
100+ """
101+
102+ traced_program = trace (model , inputs , dump_graphs = dump_graphs )
103+ prepared_program = prepare_traced_pt2 (
104+ traced_program , quantizer , dump_graphs = dump_graphs
105+ )
106+
107+ return prepared_program
108+
109+
110+ def prepare_traced_pt2 (
88111 program : ExportedProgram ,
89112 quantizer : CadenceQuantizer ,
90113 dump_graphs : bool = False ,
@@ -93,7 +116,7 @@ def prepare_pt2(
93116 Prepare a model using the given quantizer.
94117 The quantizer must be supplied and be the same as the one used to
95118 fuse the model later, if applicable. If you do not expect that behavior,
96- please use quantize_and_fuse_pt2 instead, which will instantiate a
119+ please use quantize_pt2 instead, which will instantiate a
97120 default quantizer for you if needed.
98121 Returns a GraphModule with the prepared model.
99122 """
@@ -137,7 +160,7 @@ def fuse_pt2(
137160 """
138161 Fuse a converted graph module using the given quantizer.
139162 The quantizer must be the same as the one used to convert the model.
140- If you do not expect that behavior, please use quantize_and_fuse_pt2 instead,
163+ If you do not expect that behavior, please use quantize_pt2 instead,
141164 which will instantiate a default quantizer for you if needed.
142165 Returns a GraphModule with the fused model.
143166 """
@@ -179,7 +202,7 @@ def quantize_pt2(
179202 logging .info (program .graph .print_tabular ())
180203
181204 # Get prepared graph module
182- prepared_gm = prepare_pt2 (program , quantizer , dump_graphs = dump_graphs )
205+ prepared_gm = prepare_pt2 (model , inputs , quantizer , dump_graphs = dump_graphs )
183206
184207 # Calibrate
185208 # If no calibration data is provided, use the inputs
0 commit comments