Commit 8349349

Enable quant fusion and const propagation by default (pytorch#10394)
Summary: This diff adds a do_quant_fusion_and_const_prop option to ExecutorchBackendConfig and enables quant fusion and constant propagation by default in the ExecuTorch LLM export flow. The passes run after all to_edge passes but before memory planning.

Differential Revision: D73513516
1 parent 7eba6d1 · commit 8349349

File tree

8 files changed: +138 −21 lines

exir/capture/_config.py

Lines changed: 3 additions & 0 deletions

@@ -102,3 +102,6 @@ class ExecutorchBackendConfig:
     # serialized in the PTE file. Its value is ignored if mutable buffers are not
     # memory planned as the names must be serialized in that case.
     emit_mutable_buffer_names: bool = False
+
+    # If set to true, we run quant fusion and constant propagation passes
+    do_quant_fusion_and_const_prop: bool = False
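
For orientation, a minimal sketch of opting into the new flag from user code (the toy model and inputs are hypothetical, not from this diff):

import torch
from torch.export import export
from executorch.exir import ExecutorchBackendConfig, to_edge

# Hypothetical toy model, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
example_inputs = (torch.randn(1, 4),)

edge = to_edge(export(model, example_inputs))
# With the flag set, quant fusion + const prop run on each exported program
# inside to_executorch(), after to_edge passes and before memory planning.
exec_prog = edge.to_executorch(
    ExecutorchBackendConfig(do_quant_fusion_and_const_prop=True)
)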

exir/passes/TARGETS

Lines changed: 2 additions & 0 deletions

@@ -154,6 +154,8 @@ python_library(
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
+        "//pytorch/ao:torchao",
+        "//executorch/exir/passes:constant_prop_pass",
     ],
 )

exir/passes/constant_prop_pass.py

Lines changed: 30 additions & 1 deletion

@@ -6,6 +6,7 @@
 
 # pyre-unsafe
 
+import logging
 from collections import OrderedDict
 from typing import cast, Mapping, Optional
 
@@ -29,6 +30,32 @@
 # Propagating aten.full can significantly increase compiled model size.
 _DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}
 
+# Do not const prop quantization primitives
+_QDQ_OPS = [
+    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
+    exir_ops.edge.quantized_decomposed.convert_element_type.no_fuse,
+    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
+    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
+    exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
+]
+try:
+    import torchao  # noqa: F401
+
+    _QDQ_OPS.extend(
+        [
+            exir_ops.edge.torchao.dequantize_affine.default,
+            exir_ops.edge.torchao.quantize_affine.default,
+            exir_ops.edge.torchao.choose_qparams_affine.default,
+        ]
+    )
+except ImportError:
+    pass
+_DEFAULT_SKIP_TARGETS.update(set(_QDQ_OPS))
+
+
 _PRIMITIVE_TYPES = (
     float,
     int,
@@ -308,7 +335,9 @@ def constant_prop_pass(
         if node.target == torch.ops.higher_order.cond
     ]
     if len(has_control_flow) > 0:
-        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
+        logging.warning(
+            "constant_prop_pass does not constant propagate in control flow modules"
+        )
 
     const_node_to_tensor = get_propagated_const_tensor_dict(
         exported_program, custom_skip_targets
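
The practical effect: quantization primitives now survive constant propagation by default. A minimal hedged sketch (edge is assumed to be an EdgeProgramManager whose graph contains a quantized_decomposed.dequantize_per_tensor node, as in the test added to test_passes.py below):

from executorch.exir.passes.constant_prop_pass import constant_prop_pass

program = constant_prop_pass(edge.exported_program())
# dequantize_per_tensor is now in _DEFAULT_SKIP_TARGETS, so the node remains
# in the graph instead of being folded into a precomputed float tensor.
print(program.graph_module.code)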

exir/passes/quant_fusion_pass.py

Lines changed: 12 additions & 0 deletions

@@ -7,6 +7,8 @@
 import torch
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
+from executorch.exir.passes.constant_prop_pass import constant_prop_pass
+from torch.export import ExportedProgram
 from torch.fx import GraphModule, subgraph_rewriter
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.utils import _pytree as pytree
@@ -139,3 +141,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
         graph_module.graph.lint()
         graph_module.graph.eliminate_dead_code()
         return PassResult(graph_module, True)
+
+
+def quant_fusion_and_const_prop_pass(program: ExportedProgram) -> ExportedProgram:
+    gm = program.graph_module
+    gm_res = QuantFusionPass(_fix_node_meta_val=True)(gm)
+    gm = gm_res.graph_module
+
+    # Do const prop pass to remove packing/dtype conversion ops
+    program = constant_prop_pass(program)
+    return program
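
A hedged usage sketch of the new helper (edge is assumed to come from to_edge(...); the call mirrors the one added to test_quant_fusion_pass.py below):

from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass

# Fuses q/dq patterns via QuantFusionPass(_fix_node_meta_val=True), then runs
# constant propagation to fold the weight-packing and dtype-conversion ops
# that fusion leaves behind.
program = quant_fusion_and_const_prop_pass(edge.exported_program())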

exir/program/_program.py

Lines changed: 8 additions & 1 deletion

@@ -52,6 +52,7 @@
 from executorch.exir.passes.normalize_view_copy_base_pass import (
     NormalizeViewCopyBasePass,
 )
+from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
 from executorch.exir.passes.remove_graph_asserts_pass import (
     RemoveGraphAssertsPass,
     RemoveNonCoreAtenOpGraphAssertsPass,
@@ -1524,9 +1525,15 @@ def to_executorch(
         after it has been transformed to the ExecuTorch backend.
         """
         config = config if config else ExecutorchBackendConfig()
-
         execution_programs: Dict[str, ExportedProgram] = {}
         for name, program in self._edge_programs.items():
+            if config.do_quant_fusion_and_const_prop:
+                if program.graph_signature.backward_signature is not None:
+                    raise Exception(
+                        "Cannot run do_quant_fusion_and_const_prop on a graph with a backward signature intended for on-device training."
+                        " Please set do_quant_fusion_and_const_prop to False in the ExecutorchBackendConfig."
+                    )
+                program = quant_fusion_and_const_prop_pass(program)
             program = weights_to_outputs_pass(program)
             program = unsafe_remove_auto_functionalized_pass(program)
             gm, new_signature = insert_write_back_for_buffers_pass(program)
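
One consequence of the new guard: joint forward/backward graphs are rejected. A hedged sketch of the failure path (edge_training is a hypothetical EdgeProgramManager wrapping a program exported for on-device training):

from executorch.exir import ExecutorchBackendConfig

try:
    edge_training.to_executorch(
        ExecutorchBackendConfig(do_quant_fusion_and_const_prop=True)
    )
except Exception as e:
    # "Cannot run do_quant_fusion_and_const_prop on a graph with a backward
    # signature intended for on-device training. ..."
    print(e)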

exir/tests/test_passes.py

Lines changed: 47 additions & 9 deletions

@@ -1279,6 +1279,7 @@ class Module(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)
+                self.w = torch.randn(3, 3)
 
             def t(self, val):
                 return val + 1
@@ -1293,8 +1294,11 @@ def false_fn(self, val):
                 return self.linear(val) - self.f(val)
 
             def forward(self, pred, x):
+                out = torch.nn.functional.linear(
+                    x, self.w.to(torch.float16).to(torch.float32)
+                )
                 return torch.ops.higher_order.cond(
-                    pred, self.true_fn, self.false_fn, [x]
+                    pred, self.true_fn, self.false_fn, [out]
                 )
 
         mod = Module()
@@ -1304,14 +1308,48 @@ def forward(self, pred, x):
            export(mod, (pred, x), strict=True),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
-        error_msg = r"constant_prop_pass for control flow is not supported yet."
-
-        # TODO(chenlai): enable constant prop pass for control flow
-        with self.assertRaisesRegex(
-            RuntimeError,
-            error_msg,
-        ):
-            _ = constant_prop_pass(edge.exported_program())
+        expected_out = edge.exported_program().module()(pred, x)
+
+        warn_log = (
+            "constant_prop_pass does not constant propagate in control flow modules"
+        )
+        with self.assertLogs(level="WARNING") as log:
+            program = constant_prop_pass(edge.exported_program())
+            self.assertIn(warn_log, log.output[0])
+
+        out = program.module()(pred, x)
+        self.assertTrue(torch.allclose(expected_out, out))
+
+        # dtype casts in parent module are const propagated
+        FileCheck().check(
+            "executorch_exir_dialects_edge__ops_aten_mm_default(x, _prop_tensor_constant"
+        ).run(program.graph_module.code)
+
+    def test_constant_prop_pass_quant_primitives(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w_int = torch.ones(3, 3, dtype=torch.int8)
+                self.w_scale = 3.0
+                self.w_zero_point = 3
+
+            def forward(self, x):
+                w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+                    self.w_int, self.w_scale, self.w_zero_point, -127, 128, torch.int8
+                )
+                return torch.nn.functional.linear(x, w_dq)
+
+        mod = M()
+        x = torch.randn([3])
+        mod(x)
+        edge = to_edge(
+            export(mod, (x,), strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        constant_prop_pass(edge.exported_program())
+        FileCheck().check(
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
+        ).run(edge.exported_program().graph_module.code)
 
     def test_mutable_buffers(self) -> None:
         def count_copies(gm: torch.fx.GraphModule) -> int:

exir/tests/test_quant_fusion_pass.py

Lines changed: 25 additions & 2 deletions

@@ -6,13 +6,17 @@
 
 # pyre-strict
 
+import copy
 import unittest
 
 import torch
 from executorch import exir
 from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
-from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
+from executorch.exir.passes.quant_fusion_pass import (
+    quant_fusion_and_const_prop_pass,
+    QuantFusionPass,
+)
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from torch.ao.quantization import (  # @manual
     float_qparams_weight_only_qconfig,
@@ -419,6 +423,7 @@ def _test_embedding_torchao(
         m = to_edge(
             export(model, example_inputs, strict=True), compile_config=compile_config
         )
+        m_copy = copy.deepcopy(m)
 
         # Before pass, we see torchao dequantize and embedding ops
         FileCheck().check_count(
@@ -468,4 +473,22 @@ def _test_embedding_torchao(
         self.assertTrue(torch.allclose(expected_outputs, actual_outputs))
 
         # Can lower to executorch
-        exec_prog = m.to_executorch()  # noqa: F841
+        exec_prog = m.to_executorch()  # noqa
+
+        # Alternative flow 2 using quant_fusion_pass on exported program
+        quant_fusion_and_const_prop_pass(m_copy.exported_program())
+        FileCheck().check_count(
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
+            1,
+            exactly=True,
+        ).check_not(
+            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
+        ).run(
+            m_copy.exported_program().graph_module.code
+        )
+
+        actual_outputs2 = m_copy.exported_program().module()(*example_inputs)
+        self.assertTrue(torch.allclose(expected_outputs, actual_outputs2))
+
+        # Can lower to executorch
+        exec_prog2 = m_copy.to_executorch()  # noqa

extension/llm/export/builder.py

Lines changed: 11 additions & 8 deletions

@@ -29,7 +29,6 @@
 
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.passes import MemoryPlanningPass
-from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 
 from executorch.extension.export_util.utils import export_to_edge, save_pte_program
@@ -504,20 +503,23 @@ def to_executorch(
         """
         Lower the model to executorch and get an ExecutorchProgram.
         """
-        to_executorch_passes = [
-            # If there are Linear operations left in the graph, let's execute
-            # them with the optimized op_linear rather than materializing a
-            # transpose followed by a regular op_mm.
-            ConvertToLinearPass(),
-            QuantFusionPass(),
-        ]
+        to_executorch_passes = []
         if passes:
             # pyre-fixme[6]: In call `list.extend`, for 1st positional argument,
             # expected `Iterable[Union[ConvertToLinearPass, QuantFusionPass]]` but
             # got `List[ExportPass]
             to_executorch_passes.extend(passes)
 
         assert self.edge_manager, "Need to run export_to_edge() first"
+
+        # If there are Linear operations left in the graph, let's execute
+        # them with the optimized op_linear rather than materializing a
+        # transpose followed by a regular op_mm.
+        # TODO: ConvertToLinearPass is not a sound pass and must be called before
+        # const propagation. It requires fixing:
+        # https://github.com/pytorch/executorch/issues/10499
+        self.edge_manager.transform([ConvertToLinearPass()])
+
         self.export_program = self.edge_manager.to_executorch(
             ExecutorchBackendConfig(
                 extract_delegate_segments=True,
@@ -526,6 +528,7 @@ def to_executorch(
                 # Optional[PassResult]]]` but got `List[Union[ConvertToLinearPass,
                 # QuantFusionPass]]`.
                 passes=to_executorch_passes,
+                do_quant_fusion_and_const_prop=True,
                 memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
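
Note the ordering: per the TODO above, ConvertToLinearPass is only sound before constants are folded, so it moves out of the post-export pass list and into an edge-level transform that runs ahead of to_executorch(). A hedged standalone sketch of the same ordering (the ConvertToLinearPass import path is assumed, and EdgeProgramManager.transform is taken to return a new manager):

from executorch.backends.transforms.convert_to_linear import ConvertToLinearPass  # path assumed
from executorch.exir import ExecutorchBackendConfig

edge_manager = edge_manager.transform([ConvertToLinearPass()])  # before const prop
exec_prog = edge_manager.to_executorch(
    ExecutorchBackendConfig(do_quant_fusion_and_const_prop=True)
)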
