
Commit c703417

mcremon-meta authored and facebook-github-bot committed
Add quantize_and_export_to_edge and quantize_and_export_to_executorch
Summary: These APIs let most users run the entire compilation flow with a single call (instead of quantizing first and then exporting separately). Reviewed By: zonglinpeng. Differential Revision: D73397438
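As a rough illustration (not part of the diff; the Linear model and inputs are hypothetical placeholders), the change collapses the previous two-step flow into one call:

    import torch
    from executorch.backends.cadence.aot.compiler import (
        export_to_edge,
        quantize_and_export_to_edge,
        quantize_pt2,
    )

    model = torch.nn.Linear(16, 8)  # hypothetical example model
    inputs = (torch.randn(1, 16),)

    # Previous flow: quantize first, then export to edge in a second call.
    quantized_model = quantize_pt2(model, inputs)
    edge_prog_manager = export_to_edge(quantized_model, inputs)

    # New flow: a single call covers the whole quantize-and-export pipeline.
    edge_prog_manager = quantize_and_export_to_edge(model, inputs)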
1 parent 59870c5 commit c703417

3 files changed: 36 additions, 13 deletions

backends/cadence/aot/compiler.py

Lines changed: 25 additions & 5 deletions

@@ -47,6 +47,8 @@
 from .utils import print_ops_info
 
 
+default_quantizer = CadenceDefaultQuantizer()
+
 # Note: this is not meant as a primary API since it can create inconsistencies
 # if the quantizer here is different from the quantizer used to convert. It is
 # however useful for unit tests to separate the converted model from the fused
@@ -145,7 +147,7 @@ def fuse_pt2(
 def quantize_pt2(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
-    quantizer: Optional[CadenceQuantizer] = None,
+    quantizer: CadenceQuantizer = default_quantizer,
     calibration_data: Optional[list[tuple[object, ...]]] = None,
     dump_graphs: bool = False,
 ) -> torch.fx.GraphModule:
@@ -159,10 +161,6 @@ def quantize_pt2(
     # Make the model inference mode by calling model.eval()
     model.eval()
 
-    # Instantiate the quantizer to CadenceQuantizer if not supplied
-    if not quantizer:
-        quantizer = CadenceDefaultQuantizer()
-
     # Get converted graph module
     converted_gm = convert_pt2(
         model, inputs, quantizer, calibration_data, dump_graphs=dump_graphs
@@ -250,6 +248,28 @@ def export_to_edge(
     return edge_prog_manager
 
 
+def quantize_and_export_to_edge(
+    model: torch.nn.Module,
+    inputs: tuple[object, ...],
+    quantizer: CadenceQuantizer = default_quantizer,
+    dump_graphs: bool = False,
+    constant_methods: Optional[dict[str, object]] = None,
+) -> EdgeProgramManager:
+    quantized_model = quantize_pt2(
+        model,
+        inputs,
+        quantizer=quantizer,
+        dump_graphs=dump_graphs,
+    )
+
+    return export_to_edge(
+        quantized_model,
+        inputs,
+        dump_graphs=dump_graphs,
+        constant_methods=constant_methods,
+    )
+
+
 def export_to_cadence(
     model: torch.nn.Module,
     inputs: tuple[object, ...],
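A custom quantizer can still be supplied through the new helper's quantizer keyword. A minimal sketch, assuming CadenceDefaultQuantizer is importable from the aot quantizer module (the model and inputs are hypothetical placeholders):

    import torch
    from executorch.backends.cadence.aot.compiler import quantize_and_export_to_edge
    from executorch.backends.cadence.aot.quantizer.quantizer import CadenceDefaultQuantizer

    model = torch.nn.Linear(16, 8)  # hypothetical example model
    inputs = (torch.randn(1, 16),)

    # Omitting the quantizer argument falls back to the module-level default_quantizer.
    edge_prog_manager = quantize_and_export_to_edge(
        model,
        inputs,
        quantizer=CadenceDefaultQuantizer(),
    )
    graph_module = edge_prog_manager.exported_program().graph_module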

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 5 additions & 3 deletions

@@ -12,7 +12,10 @@
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 from executorch.backends.cadence.aot import compiler
-from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
+from executorch.backends.cadence.aot.compiler import (
+    export_to_edge,
+    quantize_and_export_to_edge,
+)
 from executorch.backends.cadence.aot.fuse_ops import (
     FuseFullThenReshapePass,
     FuseMulIntoDequantPass,
@@ -394,9 +397,8 @@ def forward(self, x):
 
         inputs = torch.randn(2, 12, 1, 6)
         model = M()
-        quantized_model = quantize_pt2(model, (inputs,))
         graph_module = (
-            export_to_edge(quantized_model, (inputs,)).exported_program().graph_module
+            quantize_and_export_to_edge(model, (inputs,)).exported_program().graph_module
         )
         graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
         self.check_op_counts(

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 6 additions & 5 deletions

@@ -13,7 +13,10 @@
 import torch
 import torch.nn.functional as F
 from executorch.backends.cadence.aot import compiler
-from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
+from executorch.backends.cadence.aot.compiler import (
+    export_to_edge,
+    quantize_and_export_to_edge,
+)
 from executorch.backends.cadence.aot.graph_builder import (
     GraphBuilder,
     single_op_builder,
@@ -850,9 +853,8 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar(
 
         inputs = (x,)
         model = torch.nn.Linear(in_features=in_features, out_features=out_features)
-        quantized_model = quantize_pt2(model, inputs)
 
-        exported_program = export_to_edge(quantized_model, inputs).exported_program()
+        exported_program = quantize_and_export_to_edge(model, inputs).exported_program()
 
         # By default, the quantized linear op should have constant scalar attributes.
         self.assertTargetCountsEqual(
@@ -897,9 +899,8 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_
 
         inputs = (x,)
         model = torch.nn.Linear(in_features=in_features, out_features=out_features)
-        quantized_model = quantize_pt2(model, inputs)
 
-        exported_program = export_to_edge(quantized_model, inputs).exported_program()
+        exported_program = quantize_and_export_to_edge(model, inputs).exported_program()
 
         # By default, the quantized linear op should have constant scalar attributes.
         self.assertTargetCountsEqual(
