13 changes: 13 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -41,6 +41,7 @@ python_library(
":ops_registrations",
":passes",
":replace_ops",
":compiler_funcs",
":utils",
"//caffe2:torch",
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
@@ -332,6 +333,18 @@ python_library(
],
)

python_library(
name = "compiler_funcs",
srcs = [
"compiler_funcs.py",
],
typing = True,
deps = [
"//caffe2:torch",
"//pytorch/ao:torchao",
],
)


python_unittest(
name = "test_graph_builder",
112 changes: 65 additions & 47 deletions backends/cadence/aot/compiler.py
@@ -12,6 +12,11 @@

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
from executorch.backends.cadence.aot.compiler_funcs import (
convert as convert_fn,
prepare as prepare_fn,
trace as trace_fn,
)
from executorch.backends.cadence.aot.memory_planning import (
CadenceMemoryPlanning,
print_memory_planning_info,
@@ -35,24 +40,21 @@
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge
from torch._inductor.decomposition import remove_decompositions

from torch.export.exported_program import ExportedProgram
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from .passes import apply_exir_ops_passes, apply_torch_ops_passes

from .utils import print_ops_info


default_quantizer = CadenceDefaultQuantizer()


# Note: this is not meant as a primary API since it can create inconsistencies
# if the quantizer here is different from the quantizer used to convert. It is
# however useful for unit tests to separate the converted model from the fused
# model, to be able to get reference numerics.
# If this does not apply, please use quantize_and_fuse_pt2 instead.
# If this does not apply, please use quantize_pt2 instead.
def trace(
model: torch.nn.Module,
inputs: tuple[object, ...],
@@ -62,13 +64,6 @@ def trace(
Trace the model with export and return an ExportedProgram.
"""

# Make the model inference mode by calling model.eval()
model.eval()

# Get default decompositions
decomp_table = torch.export.default_decompositions()

# Select ops to keep
ops_to_keep = [
torch.ops.aten.conv1d.default,
torch.ops.aten.conv2d.default,
@@ -78,63 +73,77 @@
torch.ops.aten.rms_norm.default,
]

# Remove decompositions for the ops we want to keep
# pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
remove_decompositions(decomp_table, ops_to_keep)

# Export with dynamo
program = torch.export.export(model, inputs, strict=True).run_decompositions(
decomp_table
program = trace_fn(
model, inputs, is_qat=False, strict=True, ops_to_keep=ops_to_keep
)

if dump_graphs:
logging.info("Graph before quantization:")
logging.info(program.module().graph.print_tabular())
logging.info(program.graph_module.graph.print_tabular())

return program


def prepare_and_convert_pt2(
program: ExportedProgram,
def prepare_pt2(
model: torch.nn.Module,
inputs: tuple[object, ...],
quantizer: CadenceQuantizer,
calibration_data: Optional[list[tuple[object, ...]]] = None,
dump_graphs: bool = False,
) -> torch.fx.GraphModule:
"""
Prepare and convert a model using the given quantizer.
Trace and prepare a model using the given quantizer.
The quantizer must be supplied and be the same as the one used to
fuse the model later, if applicable. If you do not expect that behavior,
please use quantize_and_fuse_pt2 instead, which will instantiate a
please use quantize_pt2 instead, which will instantiate a
default quantizer for you if needed.
If calibration data is provided, it will be used to calibrate the model. If
not, the inputs will be used for calibration instead, which is useful for
unit tests but should not be used for end-to-end use cases.
Returns a GraphModule with the converted model.
Returns a GraphModule with the prepared model.
"""

# Get the graph module from the ExportedProgram
model_gm = program.module()
traced_program = trace(model, inputs, dump_graphs=dump_graphs)
prepared_program = prepare_traced_pt2(
traced_program, quantizer, dump_graphs=dump_graphs
)

assert isinstance(model_gm, torch.fx.GraphModule)
return prepared_program

# Prepare
prepared_model = prepare_pt2e(model_gm, quantizer)

# Calibrate
# If no calibration data is provided, use the inputs
if calibration_data is None:
calibration_data = [inputs]
def prepare_traced_pt2(
program: ExportedProgram,
quantizer: CadenceQuantizer,
dump_graphs: bool = False,
) -> torch.fx.GraphModule:
"""
Prepare a model using the given quantizer.
The quantizer must be supplied and be the same as the one used to
fuse the model later, if applicable. If you do not expect that behavior,
please use quantize_pt2 instead, which will instantiate a
default quantizer for you if needed.
Returns a GraphModule with the prepared model.
"""

for samples in calibration_data:
prepared_model(*samples)
prepared_model = prepare_fn(program, quantizer, is_qat=False)

if dump_graphs:
logging.info("Graph after preparation:")
logging.info(prepared_model.graph.print_tabular())

return prepared_model


def convert_pt2(
graph_module: torch.fx.GraphModule,
dump_graphs: bool = False,
) -> torch.fx.GraphModule:
"""
Convert the prepared model.
Returns a GraphModule with the converted model.
"""

# Convert
converted_model = convert_pt2e(prepared_model)
converted_model = convert_fn(graph_module)

if dump_graphs:
logging.info("Graph after quantization (before fusion):")
logging.info(model_gm.graph.print_tabular())
logging.info("Graph after convert:")
logging.info(converted_model.graph.print_tabular())

return converted_model

@@ -151,7 +160,7 @@ def fuse_pt2(
"""
Fuse a converted graph module using the given quantizer.
The quantizer must be the same as the one used to convert the model.
If you do not expect that behavior, please use quantize_and_fuse_pt2 instead,
If you do not expect that behavior, please use quantize_pt2 instead,
which will instantiate a default quantizer for you if needed.
Returns a GraphModule with the fused model.
"""
@@ -192,10 +201,19 @@ def quantize_pt2(
logging.info("Graph after trace:")
logging.info(program.graph.print_tabular())

# Get prepared graph module
prepared_gm = prepare_pt2(model, inputs, quantizer, dump_graphs=dump_graphs)

# Calibrate
# If no calibration data is provided, use the inputs
if calibration_data is None:
calibration_data = [inputs]

for samples in calibration_data:
prepared_gm(*samples)

# Get converted graph module
converted_gm = prepare_and_convert_pt2(
program, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
)
converted_gm = convert_pt2(prepared_gm, dump_graphs=dump_graphs)

# Get fused model
fused_gm = fuse_pt2(converted_gm, quantizer)
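For reviewers, a minimal sketch of how the split entry points compose after this refactor. MyModel and example_inputs are hypothetical placeholders; the call sequence mirrors quantize_pt2 and export_example.py in this diff, not a prescribed API contract.

import torch

from executorch.backends.cadence.aot.compiler import convert_pt2, fuse_pt2, prepare_pt2
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

model = MyModel()  # hypothetical eager-mode model
example_inputs = (torch.randn(1, 16),)  # hypothetical example inputs
quantizer = CadenceDefaultQuantizer()

# Trace the model and insert observers.
prepared_gm = prepare_pt2(model, example_inputs, quantizer)

# Calibrate with representative data (here, just the example inputs).
prepared_gm(*example_inputs)

# Replace observers with quantize/dequantize nodes, then fuse Cadence patterns.
converted_gm = convert_pt2(prepared_gm)
fused_gm = fuse_pt2(converted_gm, quantizer)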
63 changes: 63 additions & 0 deletions backends/cadence/aot/compiler_funcs.py
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict


from typing import Optional

import torch
from torch._inductor.decomposition import remove_decompositions
from torchao.quantization.pt2e.quantize_pt2e import (
convert_pt2e,
prepare_pt2e,
prepare_qat_pt2e,
)
from torchao.quantization.pt2e.quantizer import Quantizer


@torch.no_grad()
def trace(
model: torch.nn.Module,
inputs: tuple[object, ...],
is_qat: bool = False,
strict: bool = False,
ops_to_keep: Optional[list[torch._ops.OpOverload]] = None,
) -> torch.export.ExportedProgram:
if is_qat:
model.train()
else:
model.eval()

decomp_table = torch.export.default_decompositions()
# pyre-fixme[6]: For 1st argument expected `Dict[typing.Callable[..., typing.Any
remove_decompositions(decomp_table, ops_to_keep)
program = torch.export.export_for_training(
model, inputs, strict=strict
).run_decompositions(decomp_table)

return program


def prepare(
traced_program: torch.export.ExportedProgram,
quantizer: Quantizer,
is_qat: bool = False,
) -> torch.fx.GraphModule:
traced_model = traced_program.module()
assert isinstance(traced_model, torch.fx.GraphModule)

if is_qat:
prepared_model = prepare_qat_pt2e(traced_model, quantizer)
else:
prepared_model = prepare_pt2e(traced_model, quantizer)

return prepared_model


def convert(prepared_model: torch.fx.GraphModule) -> torch.fx.GraphModule:
converted_model = convert_pt2e(prepared_model)
return converted_model
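A minimal sketch of the QAT path these helpers enable, assuming a hypothetical model, inputs, and fine-tuning loop; the quantizer choice and the ops_to_keep list are illustrative rather than mandated by this diff.

import torch

from executorch.backends.cadence.aot import compiler_funcs
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

model = MyModel()  # hypothetical eager-mode model
example_inputs = (torch.randn(1, 16),)  # hypothetical example inputs

# Export in train mode and keep selected ops from being decomposed.
program = compiler_funcs.trace(
    model,
    example_inputs,
    is_qat=True,
    strict=True,
    ops_to_keep=[torch.ops.aten.linear.default],
)

# Insert fake-quant modules for QAT instead of plain observers.
prepared = compiler_funcs.prepare(program, CadenceDefaultQuantizer(), is_qat=True)

# ... run a (hypothetical) fine-tuning loop on prepared here ...

# Lower the fake-quant nodes to quantize/dequantize ops.
converted = compiler_funcs.convert(prepared)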
14 changes: 9 additions & 5 deletions backends/cadence/aot/export_example.py
@@ -15,10 +15,10 @@
from typing import Any, Tuple

from executorch.backends.cadence.aot.compiler import (
convert_pt2,
export_to_executorch_gen_etrecord,
fuse_pt2,
prepare_and_convert_pt2,
trace,
prepare_pt2,
)

from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
@@ -49,11 +49,15 @@ def export_model(
# Instantiate the quantizer
quantizer = CadenceDefaultQuantizer()

# Trace the model
ep = trace(model, example_inputs)
# Prepare the model
prepared_gm = prepare_pt2(model, example_inputs, quantizer)

# Calibrate the model
for samples in [example_inputs]:
prepared_gm(*samples)

# Convert the model
converted_model = prepare_and_convert_pt2(ep, example_inputs, quantizer)
converted_model = convert_pt2(prepared_gm)

# Get reference outputs from converted model
ref_outputs = converted_model(*example_inputs)