Commit 5ae38ac

mcremon-meta authored and facebook-github-bot committed

Add quantize_and_export_to_edge and quantize_and_export_to_executorch (#10379)

Summary: Adding these APIs lets most users run the whole compilation flow with a single call, instead of quantizing first and then exporting separately.

Reviewed By: zonglinpeng

Differential Revision: D73397438
1 parent 2837867 commit 5ae38ac
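
For context, here is a minimal usage sketch of the new single-call flow next to the previous two-step flow. The toy Linear module, tensor shapes, and variable names are illustrative only and are not part of the commit; the imports mirror those used in the test files touched below.

import torch

from executorch.backends.cadence.aot.compiler import (
    export_to_edge,
    quantize_and_export_to_edge,
    quantize_pt2,
)

# Toy module and example inputs, for illustration only.
model = torch.nn.Linear(16, 8)
example_inputs = (torch.randn(1, 16),)

# Previous flow: quantize first, then export to edge.
quantized_model = quantize_pt2(model, example_inputs)
edge_manager = export_to_edge(quantized_model, example_inputs)

# New flow: one call performs both steps.
edge_manager = quantize_and_export_to_edge(model, example_inputs)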

File tree: 3 files changed (+35, -8 lines)

backends/cadence/aot/compiler.py

Lines changed: 24 additions & 0 deletions

@@ -47,6 +47,8 @@
 from .utils import print_ops_info


+default_quantizer = CadenceDefaultQuantizer()
+
 # Note: this is not meant as a primary API since it can create inconsistencies
 # if the quantizer here is different from the quantizer used to convert. It is
 # however useful for unit tests to separate the converted model from the fused
@@ -250,6 +252,28 @@ def export_to_edge(
     return edge_prog_manager


+def quantize_and_export_to_edge(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: Optional[CadenceQuantizer] = None,
+    dump_graphs: bool = False,
+    constant_methods: Optional[dict[str, object]] = None,
+) -> EdgeProgramManager:
+    quantized_model = quantize_pt2(
+        model,
+        inputs,
+        quantizer=quantizer,
+        dump_graphs=dump_graphs,
+    )
+
+    return export_to_edge(
+        quantized_model,
+        inputs,
+        dump_graphs=dump_graphs,
+        constant_methods=constant_methods,
+    )
+
+
 def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
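
The new helper also forwards an optional quantizer, so a caller that does not want the default can pass its own CadenceQuantizer. The sketch below shows that override; the import path for CadenceDefaultQuantizer is an assumption and may differ in the actual tree.

import torch

# Assumed import path for CadenceDefaultQuantizer; it may live elsewhere.
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer
from executorch.backends.cadence.aot.compiler import quantize_and_export_to_edge

model = torch.nn.Linear(16, 8)          # illustrative toy module
example_inputs = (torch.randn(1, 16),)

edge_manager = quantize_and_export_to_edge(
    model,
    example_inputs,
    quantizer=CadenceDefaultQuantizer(),  # or any custom CadenceQuantizer
)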

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 5 additions & 3 deletions

@@ -12,7 +12,10 @@
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 from executorch.backends.cadence.aot import compiler
-from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
+from executorch.backends.cadence.aot.compiler import (
+    export_to_edge,
+    quantize_and_export_to_edge,
+)
 from executorch.backends.cadence.aot.fuse_ops import (
     FuseFullThenReshapePass,
     FuseMulIntoDequantPass,
@@ -414,9 +417,8 @@ def forward(self, x):

         inputs = torch.randn(2, 12, 1, 6)
         model = M()
-        quantized_model = quantize_pt2(model, (inputs,))
         graph_module = (
-            export_to_edge(quantized_model, (inputs,)).exported_program().graph_module
+            quantize_and_export_to_edge(model, (inputs,)).exported_program().graph_module
         )
         graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
         self.check_op_counts(

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 6 additions & 5 deletions

@@ -13,7 +13,10 @@
 import torch
 import torch.nn.functional as F
 from executorch.backends.cadence.aot import compiler
-from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
+from executorch.backends.cadence.aot.compiler import (
+    export_to_edge,
+    quantize_and_export_to_edge,
+)
 from executorch.backends.cadence.aot.graph_builder import (
     GraphBuilder,
     single_op_builder,
@@ -851,9 +854,8 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar(

         inputs = (x,)
         model = torch.nn.Linear(in_features=in_features, out_features=out_features)
-        quantized_model = quantize_pt2(model, inputs)

-        exported_program = export_to_edge(quantized_model, inputs).exported_program()
+        exported_program = quantize_and_export_to_edge(model, inputs).exported_program()

         # By default, the quantized linear op should have constant scalar attributes.
         self.assertTargetCountsEqual(
@@ -898,9 +900,8 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_

         inputs = (x,)
         model = torch.nn.Linear(in_features=in_features, out_features=out_features)
-        quantized_model = quantize_pt2(model, inputs)

-        exported_program = export_to_edge(quantized_model, inputs).exported_program()
+        exported_program = quantize_and_export_to_edge(model, inputs).exported_program()

         # By default, the quantized linear op should have constant scalar attributes.
         self.assertTargetCountsEqual(
