Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions backends/cadence/aot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@
from .utils import print_ops_info


default_quantizer = CadenceDefaultQuantizer()


# Note: this is not meant as a primary API since it can create inconsistencies
# if the quantizer here is different from the quantizer used to convert. It is
# however useful for unit tests to separate the converted model from the fused
# model, to be able to get reference numerics.
# If this does not apply, please use quantize_and_fuse_pt2 instead.
def prepare_and_convert_pt2(
model: torch.nn.Module,
inputs: tuple[object, ...],
Expand Down Expand Up @@ -245,6 +253,28 @@ def export_to_edge(
return edge_prog_manager


def quantize_and_export_to_edge(
    model: torch.nn.Module,
    inputs: tuple[object, ...],
    quantizer: Optional[CadenceQuantizer] = None,
    dump_graphs: bool = False,
    constant_methods: Optional[dict[str, object]] = None,
) -> EdgeProgramManager:
    """Quantize ``model`` and export the quantized result to edge in one call.

    Convenience wrapper that chains ``quantize_pt2`` and ``export_to_edge``
    so callers do not have to carry the intermediate quantized model around.

    Args:
        model: Module to quantize and export.
        inputs: Example inputs; the same tuple is handed to both steps.
        quantizer: Forwarded unchanged to ``quantize_pt2``. How ``None`` is
            handled is decided there (presumably a default quantizer is
            used -- confirm against ``quantize_pt2``).
        dump_graphs: Forwarded to both ``quantize_pt2`` and
            ``export_to_edge``.
        constant_methods: Forwarded to ``export_to_edge`` only.

    Returns:
        Whatever ``export_to_edge`` returns for the quantized model.
    """
    # Step 1: quantization, driven by the caller-supplied quantizer (if any).
    quantized = quantize_pt2(
        model, inputs, quantizer=quantizer, dump_graphs=dump_graphs
    )

    # Step 2: export of the quantized model, reusing the same example inputs.
    edge_manager = export_to_edge(
        quantized,
        inputs,
        dump_graphs=dump_graphs,
        constant_methods=constant_methods,
    )
    return edge_manager


def export_to_cadence(
model: torch.nn.Module,
inputs: tuple[object, ...],
Expand Down
10 changes: 7 additions & 3 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
from executorch.backends.cadence.aot.compiler import (
export_to_edge,
quantize_and_export_to_edge,
)
from executorch.backends.cadence.aot.fuse_ops import (
FuseFullThenReshapePass,
FuseMulIntoDequantPass,
Expand Down Expand Up @@ -414,9 +417,10 @@ def forward(self, x):

inputs = torch.randn(2, 12, 1, 6)
model = M()
quantized_model = quantize_pt2(model, (inputs,))
graph_module = (
export_to_edge(quantized_model, (inputs,)).exported_program().graph_module
quantize_and_export_to_edge(model, (inputs,))
.exported_program()
.graph_module
)
graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
self.check_op_counts(
Expand Down
11 changes: 6 additions & 5 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import torch
import torch.nn.functional as F
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
from executorch.backends.cadence.aot.compiler import (
export_to_edge,
quantize_and_export_to_edge,
)
from executorch.backends.cadence.aot.graph_builder import (
GraphBuilder,
single_op_builder,
Expand Down Expand Up @@ -851,9 +854,8 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar(

inputs = (x,)
model = torch.nn.Linear(in_features=in_features, out_features=out_features)
quantized_model = quantize_pt2(model, inputs)

exported_program = export_to_edge(quantized_model, inputs).exported_program()
exported_program = quantize_and_export_to_edge(model, inputs).exported_program()

# By default, the quantized linear op should have constant scalar attributes.
self.assertTargetCountsEqual(
Expand Down Expand Up @@ -898,9 +900,8 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_

inputs = (x,)
model = torch.nn.Linear(in_features=in_features, out_features=out_features)
quantized_model = quantize_pt2(model, inputs)

exported_program = export_to_edge(quantized_model, inputs).exported_program()
exported_program = quantize_and_export_to_edge(model, inputs).exported_program()

# By default, the quantized linear op should have constant scalar attributes.
self.assertTargetCountsEqual(
Expand Down
Loading