Skip to content

Commit cb7acb4

Browse files
metascroyfacebook-github-bot
authored andcommitted
Enable quant fusion and const propagation by default
Summary: This diff enables quant fusion and constant propagation by default in ExecuTorch. It occurs after all to_edge passes, but before memory planning. Differential Revision: D73513516
1 parent 318e26b commit cb7acb4

File tree

5 files changed

+68
-13
lines changed

5 files changed

+68
-13
lines changed

exir/capture/_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,6 @@ class ExecutorchBackendConfig:
102102
# serialized in the PTE file. Its value is ignored if mutable buffers are not
103103
# memory planned as the names must be serialized in that case.
104104
emit_mutable_buffer_names: bool = False
105+
106+
# If set to true, we run quant fusion and constant propagation passes
107+
do_quant_fusion_and_const_prop: bool = True

exir/passes/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ python_library(
154154
"//caffe2:torch",
155155
"//executorch/exir:pass_base",
156156
"//executorch/exir/dialects:lib",
157+
"//pytorch/ao:torchao",
158+
"//executorch/exir/passes:constant_prop_pass",
157159
],
158160
)
159161

exir/passes/quant_fusion_pass.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from torch.fx import GraphModule, subgraph_rewriter
1111
from torch.fx.passes.infra.pass_base import PassResult
1212
from torch.utils import _pytree as pytree
13+
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
14+
from torch.export import ExportedProgram
1315

1416
from ._quant_patterns_and_replacements import get_quant_patterns_and_replacements
1517

@@ -139,3 +141,35 @@ def call(self, graph_module: GraphModule) -> PassResult:
139141
graph_module.graph.lint()
140142
graph_module.graph.eliminate_dead_code()
141143
return PassResult(graph_module, True)
144+
145+
146+
import torchao # noqa: F401
147+
_QDQ_OPS = [
148+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
149+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
150+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
151+
exir_ops.edge.quantized_decomposed.convert_element_type.no_fuse,
152+
exir_ops.edge.torchao.dequantize_affine,
153+
exir_ops.edge.torchao.dequantize_affine.default,
154+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
155+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
156+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
157+
exir_ops.edge.torchao.quantize_affine.default,
158+
exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
159+
exir_ops.edge.torchao.choose_qparams_affine.default,
160+
]
161+
162+
def quant_fusion_and_const_prop_pass(program: ExportedProgram) -> ExportedProgram:
163+
gm = program.graph_module
164+
gm_res = QuantFusionPass(_fix_node_meta_val=True)(gm)
165+
gm = gm_res.graph_module
166+
program.validate()
167+
168+
# Assert no Q/DQ ops remain in graph after quant fusion pass
169+
for node in gm.graph.nodes:
170+
if node.target in _QDQ_OPS:
171+
raise AssertionError(f"Q/DQ op {node.target} remains in graph after quant fusion pass")
172+
173+
# Do const prop pass to remove packing ops
174+
program = constant_prop_pass(program)
175+
return program

exir/program/_program.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from executorch.exir.passes.normalize_view_copy_base_pass import (
5353
NormalizeViewCopyBasePass,
5454
)
55+
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
5556
from executorch.exir.passes.remove_graph_asserts_pass import (
5657
RemoveGraphAssertsPass,
5758
RemoveNonCoreAtenOpGraphAssertsPass,
@@ -1526,9 +1527,12 @@ def to_executorch(
15261527
after it has been transformed to the ExecuTorch backend.
15271528
"""
15281529
config = config if config else ExecutorchBackendConfig()
1529-
15301530
execution_programs: Dict[str, ExportedProgram] = {}
15311531
for name, program in self._edge_programs.items():
1532+
# Do constant propagation. This is needed for some quant fusion
1533+
# passes to work correctly
1534+
if config.do_quant_fusion_and_const_prop:
1535+
program = quant_fusion_and_const_prop_pass(program)
15321536
program = weights_to_outputs_pass(program)
15331537
program = unsafe_remove_auto_functionalized_pass(program)
15341538
gm, new_signature = insert_write_back_for_buffers_pass(program)

exir/tests/test_quant_fusion_pass.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from executorch import exir
1313
from executorch.exir import EdgeCompileConfig, to_edge
1414
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
15-
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
15+
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass, quant_fusion_and_const_prop_pass
1616
from executorch.exir.tests.common import register_additional_test_aten_ops
1717
from torch.ao.quantization import ( # @manual
1818
float_qparams_weight_only_qconfig,
@@ -33,7 +33,7 @@
3333
from torch.testing import FileCheck
3434
from torchao.quantization.granularity import PerAxis, PerGroup
3535
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
36-
36+
import copy
3737

3838
class TestQuantFusionPass(unittest.TestCase):
3939
@classmethod
@@ -419,6 +419,7 @@ def _test_embedding_torchao(
419419
m = to_edge(
420420
export(model, example_inputs, strict=True), compile_config=compile_config
421421
)
422+
m_copy = copy.deepcopy(m)
422423

423424
# Before pass, we see torchao dequantize and embedding ops
424425
FileCheck().check_count(
@@ -437,13 +438,9 @@ def _test_embedding_torchao(
437438

438439
# After pass, we see packing op and quantized embedding op, but no torchao dequantize op
439440
FileCheck().check_count(
440-
"executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
441-
1 if bit_width < 8 else 0,
442-
exactly=True,
441+
"executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default", 1 if bit_width < 8 else 0, exactly=True
443442
).check_count(
444-
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
445-
1,
446-
exactly=True,
443+
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
447444
).check_not(
448445
"executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default"
449446
).run(
@@ -454,9 +451,7 @@ def _test_embedding_torchao(
454451

455452
# After constant prop, we see quantized embedding op, but no packing op
456453
FileCheck().check_count(
457-
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
458-
1,
459-
exactly=True,
454+
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
460455
).check_not(
461456
"executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
462457
).run(
@@ -468,4 +463,21 @@ def _test_embedding_torchao(
468463
self.assertTrue(torch.allclose(expected_outputs, actual_outputs))
469464

470465
# Can lower to executorch
471-
exec_prog = m.to_executorch() # noqa: F841
466+
exec_prog = m.to_executorch() # noqa
467+
468+
469+
# Alternative flow 2 using quant_fusion_pass on exported program
470+
quant_fusion_and_const_prop_pass(m_copy.exported_program())
471+
FileCheck().check_count(
472+
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
473+
).check_not(
474+
"executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
475+
).run(
476+
m_copy.exported_program().graph_module.code
477+
)
478+
479+
actual_outputs2 = m_copy.exported_program().module()(*example_inputs)
480+
self.assertTrue(torch.allclose(expected_outputs, actual_outputs2))
481+
482+
# Can lower to executorch
483+
exec_prog2 = m_copy.to_executorch() # noqa

0 commit comments

Comments
 (0)