diff --git a/backends/cadence/aot/compiler.py b/backends/cadence/aot/compiler.py index 26a0437ac25..40807a87232 100644 --- a/backends/cadence/aot/compiler.py +++ b/backends/cadence/aot/compiler.py @@ -54,7 +54,7 @@ # if the quantizer here is different from the quantizer used to convert. It is # however useful for unit tests to separate the converted model from the fused # model, to be able to get reference numerics. -# If this does not apply, please use quantize_and_fuse_pt2 instead. +# If this does not apply, please use quantize_pt2 instead. def trace( model: torch.nn.Module, inputs: tuple[object, ...], @@ -85,6 +85,29 @@ def trace( def prepare_pt2( + model: torch.nn.Module, + inputs: tuple[object, ...], + quantizer: CadenceQuantizer, + dump_graphs: bool = False, +) -> torch.fx.GraphModule: + """ + Trace and Prepare a model using the given quantizer. + The quantizer must be supplied and be the same as the one used to + fuse the model later, if applicable. If you do not expect that behavior, + please use quantize_pt2 instead, which will instantiate a + default quantizer for you if needed. + Returns a GraphModule with the prepared model. + """ + + traced_program = trace(model, inputs, dump_graphs=dump_graphs) + prepared_program = prepare_traced_pt2( + traced_program, quantizer, dump_graphs=dump_graphs + ) + + return prepared_program + + +def prepare_traced_pt2( program: ExportedProgram, quantizer: CadenceQuantizer, dump_graphs: bool = False, @@ -93,7 +116,7 @@ def prepare_pt2( Prepare a model using the given quantizer. The quantizer must be supplied and be the same as the one used to fuse the model later, if applicable. If you do not expect that behavior, - please use quantize_and_fuse_pt2 instead, which will instantiate a + please use quantize_pt2 instead, which will instantiate a default quantizer for you if needed. Returns a GraphModule with the prepared model. """ @@ -137,7 +160,7 @@ def fuse_pt2( """ Fuse a converted graph module using the given quantizer. 
The quantizer must be the same as the one used to convert the model. - If you do not expect that behavior, please use quantize_and_fuse_pt2 instead, + If you do not expect that behavior, please use quantize_pt2 instead, which will instantiate a default quantizer for you if needed. Returns a GraphModule with the fused model. """ @@ -179,7 +202,7 @@ def quantize_pt2( logging.info(program.graph.print_tabular()) # Get prepared graph module - prepared_gm = prepare_pt2(program, quantizer, dump_graphs=dump_graphs) + prepared_gm = prepare_pt2(model, inputs, quantizer, dump_graphs=dump_graphs) # Calibrate # If no calibration data is provided, use the inputs diff --git a/backends/cadence/aot/export_example.py b/backends/cadence/aot/export_example.py index 3bf126fb400..3ad7076a1b7 100644 --- a/backends/cadence/aot/export_example.py +++ b/backends/cadence/aot/export_example.py @@ -19,7 +19,6 @@ export_to_executorch_gen_etrecord, fuse_pt2, prepare_pt2, - trace, ) from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer @@ -50,8 +49,10 @@ def export_model( # Instantiate the quantizer quantizer = CadenceDefaultQuantizer() - # Trace the model - ep = trace(model, example_inputs) + # Prepare the model + prepared_gm = prepare_pt2(model, example_inputs, quantizer) + + # Calibrate the model + for samples in [example_inputs]: + prepared_gm(*samples) - # Prepare the model - prepared_gm = prepare_pt2(ep, quantizer)