Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
# if the quantizer here is different from the quantizer used to convert. It is
# however useful for unit tests to separate the converted model from the fused
# model, to be able to get reference numerics.
# If this does not apply, please use quantize_and_fuse_pt2 instead.
# If this does not apply, please use quantize_pt2 instead.
def trace(
model: torch.nn.Module,
inputs: tuple[object, ...],
Expand Down Expand Up @@ -85,6 +85,29 @@ def trace(


def prepare_pt2(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: CadenceQuantizer,
    dump_graphs: bool = False,
) -> torch.fx.GraphModule:
    """
    Trace and prepare a model using the given quantizer.

    The quantizer must be supplied and be the same as the one used to
    fuse the model later, if applicable. If you do not expect that
    behavior, please use quantize_pt2 instead, which will instantiate a
    default quantizer for you if needed.
    Returns a GraphModule with the prepared model.
    """
    # Export the eager model first, then hand the exported program to the
    # quantizer's prepare step; both stages honor dump_graphs for debugging.
    return prepare_traced_pt2(
        trace(model, inputs, dump_graphs=dump_graphs),
        quantizer,
        dump_graphs=dump_graphs,
    )


def prepare_traced_pt2(
program: ExportedProgram,
quantizer: CadenceQuantizer,
dump_graphs: bool = False,
Expand All @@ -93,7 +116,7 @@ def prepare_pt2(
Prepare a model using the given quantizer.
The quantizer must be supplied and be the same as the one used to
fuse the model later, if applicable. If you do not expect that behavior,
please use quantize_and_fuse_pt2 instead, which will instantiate a
please use quantize_pt2 instead, which will instantiate a
default quantizer for you if needed.
Returns a GraphModule with the prepared model.
"""
Expand Down Expand Up @@ -137,7 +160,7 @@ def fuse_pt2(
"""
Fuse a converted graph module using the given quantizer.
The quantizer must be the same as the one used to convert the model.
If you do not expect that behavior, please use quantize_and_fuse_pt2 instead,
If you do not expect that behavior, please use quantize_pt2 instead,
which will instantiate a default quantizer for you if needed.
Returns a GraphModule with the fused model.
"""
Expand Down Expand Up @@ -179,7 +202,7 @@ def quantize_pt2(
logging.info(program.graph.print_tabular())

# Get prepared graph module
prepared_gm = prepare_pt2(program, quantizer, dump_graphs=dump_graphs)
prepared_gm = prepare_pt2(model, inputs, quantizer, dump_graphs=dump_graphs)

# Calibrate
# If no calibration data is provided, use the inputs
Expand Down
9 changes: 6 additions & 3 deletions backends/cadence/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
export_to_executorch_gen_etrecord,
fuse_pt2,
prepare_pt2,
trace,
)

from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
Expand Down Expand Up @@ -50,8 +49,12 @@ def export_model(
# Instantiate the quantizer
quantizer = CadenceDefaultQuantizer()

# Trace the model
ep = trace(model, example_inputs)
# Prepare the model
prepared_gm = prepare_pt2(model, example_inputs, quantizer)

# Calibrate the model
for samples in [example_inputs]:
prepared_gm(*samples)

# Prepare the model
prepared_gm = prepare_pt2(ep, quantizer)
Expand Down
Loading