Commit b19c675

mcremon-meta authored and facebook-github-bot committed

Introduce the get_fake_quant_model API (#12997)

Summary: As titled. This way, people can get the converted model with quantized numerics and use it to compare with implementation outputs.

Differential Revision: D79105110

1 parent 275adee, commit b19c675
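
For context, here is a minimal usage sketch of the new API (not part of the commit). The signature follows the diff below; the toy model and the import paths are assumptions to check against your ExecuTorch checkout.

```python
# Hedged sketch: module paths are inferred from backends/cadence/aot/ in the
# ExecuTorch tree and may need adjusting; the model/inputs are placeholders.
import torch

from executorch.backends.cadence.aot.compiler import get_fake_quant_model
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

model = torch.nn.Linear(16, 8)  # stand-in eager-mode model
inputs = (torch.randn(1, 16),)  # example inputs, also used for calibration

# quantizer is deliberately required (not Optional) in the new API.
fake_quant_gm = get_fake_quant_model(
    model,
    inputs,
    quantizer=CadenceDefaultQuantizer(),
)

# The returned torch.fx.GraphModule has quantized numerics, so its outputs
# serve as a reference for comparing against implementation outputs.
reference_out = fake_quant_gm(*inputs)
```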

File tree: 1 file changed (+37, -19)

backends/cadence/aot/compiler.py (37 additions, 19 deletions)

```diff
@@ -149,29 +149,17 @@ def fuse_pt2(
     return converted_graph_module
 
 
-def quantize_pt2(
-    model: torch.nn.Module,
+# Note: quantizer is not optional here to force the user to supply a quantizer
+# and ensure consistency is more likely to be maintained.
+def get_fake_quant_model(model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: Optional[CadenceQuantizer] = None,
+    quantizer: CadenceQuantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
-) -> ExportedProgram:
-    """
-    Trace, prepare, convert and fuse the model using the given quantizer.
-    If calibration data is provided, it will be used to calibrate the model. If
-    not, the inputs will be used for calibration instead, which is useful for
-    unit tests but should not be used for end-to-end use cases.
-    Returns a GraphModule with the quantized model.
-    Note: this function should not be called directly in general. Please use
-    quantize_and_export_to_executorch for most needs.
-    """
+) -> torch.fx.GraphModule:
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    # Instantiate the quantizer to CadenceQuantizer if not supplied
-    if not quantizer:
-        quantizer = CadenceDefaultQuantizer()
-
     program = trace(model, inputs, dump_graphs=dump_graphs)
 
     if dump_graphs:
@@ -191,6 +179,37 @@ def quantize_pt2(
 
     # Get converted graph module
     converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)
+    return converted_gm
+
+
+def quantize_pt2(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: Optional[CadenceQuantizer] = None,
+    calibration_data: Optional[list[tuple[object, ...]]] = None,
+    dump_graphs: bool = False,
+) -> ExportedProgram:
+    """
+    Trace, prepare, convert and fuse the model using the given quantizer.
+    If calibration data is provided, it will be used to calibrate the model. If
+    not, the inputs will be used for calibration instead, which is useful for
+    unit tests but should not be used for end-to-end use cases.
+    Returns a GraphModule with the quantized model.
+    Note: this function should not be called directly in general. Please use
+    quantize_and_export_to_executorch for most needs.
+    """
+    # Instantiate the quantizer to CadenceQuantizer if not supplied
+    if not quantizer:
+        quantizer = CadenceDefaultQuantizer()
+
+    # Get the converted (aka fake quant) graph module
+    converted_gm = get_fake_quant_model(
+        model,
+        inputs,
+        quantizer=quantizer,
+        calibration_data=calibration_data,
+        dump_graphs=dump_graphs
+    )
 
     # Get fused model
     fused_gm = fuse_pt2(converted_gm, quantizer)
@@ -203,7 +222,6 @@ def quantize_pt2(
 
     return program
 
-
 TO_EDGE_OP_EXCEPTION_LIST: list[torch._ops.OpOverload] = [
     torch.ops.aten._linalg_det.default,
     torch.ops.aten._linalg_svd.default,
@@ -214,7 +232,7 @@ def quantize_pt2(
     torch.ops.aten.angle.default,
     torch.ops.aten.rms_norm.default,
 ]
-TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload, ...] = [
+TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload] = [
     torch.ops.aten.rms_norm.default,
 ]
 
```
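
Per the summary, the point of exposing the fake-quant model is numerical comparison against implementation outputs. A sketch of such a comparison follows; `compare_with_implementation` and `impl_output` are hypothetical names, not part of this commit.

```python
# Hedged sketch: how one might use the fake-quant GraphModule returned by
# get_fake_quant_model as a numerical reference. Nothing here is from the commit.
import torch


def compare_with_implementation(
    fake_quant_gm: torch.fx.GraphModule,
    impl_output: torch.Tensor,
    inputs: tuple,
    atol: float = 0.0,
) -> None:
    """Check an implementation's output against the fake-quant reference."""
    ref = fake_quant_gm(*inputs)
    if isinstance(ref, (tuple, list)):
        ref = ref[0]  # single-output assumption for this sketch
    # Fake-quant outputs model the quantized numerics, so tight tolerances
    # are appropriate; atol=0.0 demands bit-exact agreement.
    torch.testing.assert_close(impl_output, ref, rtol=0.0, atol=atol)
```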

Comments (0)