Skip to content

Commit 7aa12ab

Browse files
metascroyfacebook-github-bot
authored andcommitted
Enable quant fusion and const propagation by default (pytorch#10394)
Summary: This diff enables quant fusion and constant propagation by default in ExecuTorch. It occurs after all to_edge passes, but before memory planning. Differential Revision: D73513516
1 parent 6b877de commit 7aa12ab

File tree

11 files changed

+166
-52
lines changed

11 files changed

+166
-52
lines changed

exir/capture/_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,6 @@ class ExecutorchBackendConfig:
102102
# serialized in the PTE file. Its value is ignored if mutable buffers are not
103103
# memory planned as the names must be serialized in that case.
104104
emit_mutable_buffer_names: bool = False
105+
106+
# If set to true, we run quant fusion and constant propagation passes
107+
do_quant_fusion_and_const_prop: bool = True

exir/emit/test/test_emit.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -431,8 +431,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
431431
.executorch_program
432432
)
433433
# The value for beta should appear before alpha
434-
self.assertEqual(program.execution_plan[0].values[12].val, Int(3))
435-
self.assertEqual(program.execution_plan[0].values[13].val, Int(2))
434+
self.assertEqual(program.execution_plan[0].values[4].val, Int(3))
435+
self.assertEqual(program.execution_plan[0].values[5].val, Int(2))
436436

437437
def test_kwargs2(self) -> None:
438438
"""Tests that the kwargs are placed in the order specified by
@@ -451,10 +451,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
451451
to_edge(export(f, (x,), strict=True)).to_executorch().executorch_program
452452
)
453453
# The value for right should appear before side
454-
self.assertEqual(program.execution_plan[0].values[6].val, Bool(False))
455-
self.assertEqual(program.execution_plan[0].values[7].val, Bool(True))
456-
self.assertEqual(program.execution_plan[0].values[8].val, String("right"))
457-
self.assertEqual(program.execution_plan[0].values[9].val, Null())
454+
self.assertEqual(program.execution_plan[0].values[3].val, Bool(False))
455+
self.assertEqual(program.execution_plan[0].values[4].val, Bool(True))
456+
self.assertEqual(program.execution_plan[0].values[5].val, String("right"))
457+
self.assertEqual(program.execution_plan[0].values[6].val, Null())
458458

459459
def _assertCallLength(self, program: Program, idx: int, expected_len: int) -> None:
460460
instr_args = program.execution_plan[0].chains[0].instructions[idx].instr_args
@@ -532,24 +532,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
532532
# Check the mul operator's stack trace contains f -> g -> h
533533
self.assertTrue(
534534
"return torch.mul(x, torch.randn(3, 2))"
535-
in program.execution_plan[0].chains[0].stacktrace[1].items[-1].context
535+
in program.execution_plan[0].chains[0].stacktrace[0].items[-1].context
536536
)
537537
self.assertEqual(
538-
program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f"
538+
program.execution_plan[0].chains[0].stacktrace[0].items[-1].name, "f"
539539
)
540540
self.assertEqual(
541-
program.execution_plan[0].chains[0].stacktrace[1].items[-2].name, "g"
541+
program.execution_plan[0].chains[0].stacktrace[0].items[-2].name, "g"
542542
)
543543
self.assertEqual(
544-
program.execution_plan[0].chains[0].stacktrace[1].items[-3].name, "forward"
544+
program.execution_plan[0].chains[0].stacktrace[0].items[-3].name, "forward"
545545
)
546546

547547
# Check the sin operator's stack trace contains g -> h
548548
self.assertEqual(
549-
program.execution_plan[0].chains[0].stacktrace[2].items[-1].name, "g"
549+
program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "g"
550550
)
551551
self.assertEqual(
552-
program.execution_plan[0].chains[0].stacktrace[2].items[-2].name, "forward"
552+
program.execution_plan[0].chains[0].stacktrace[1].items[-2].name, "forward"
553553
)
554554

555555
def test_stacktrace_off(self) -> None:
@@ -878,10 +878,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
878878
.executorch_program.execution_plan[0]
879879
.non_const_buffer_sizes
880880
)
881-
881+
882+
config = ExecutorchBackendConfig(
883+
do_quant_fusion_and_const_prop=False,
884+
)
882885
edge_program_manager = to_edge(export(f, (torch.ones(3, 2),), strict=True))
883886
non_const_buffer_size_without_const_prop_pass = (
884-
edge_program_manager.to_executorch()
887+
edge_program_manager.to_executorch(config)
885888
.executorch_program.execution_plan[0]
886889
.non_const_buffer_sizes
887890
)
@@ -1510,7 +1513,12 @@ def forward(self, x):
15101513
self.assertEqual(model.W1.untyped_storage().nbytes(), 8)
15111514
self.assertEqual(model.W2.nbytes, 4)
15121515
self.assertEqual(model.W2.untyped_storage().nbytes(), 8)
1513-
program = to_edge(export(model, (torch.ones(1),), strict=True)).to_executorch()
1516+
1517+
# Without this, the views get
1518+
config = exir.ExecutorchBackendConfig(
1519+
do_quant_fusion_and_const_prop=False,
1520+
)
1521+
program = to_edge(export(model, (torch.ones(1),), strict=True)).to_executorch(config)
15141522

15151523
program = program._emitter_output.program
15161524
# each emitted weight is not a view
@@ -1531,7 +1539,10 @@ def forward(self, x):
15311539
program = program._emitter_output.program
15321540
# confirm that the buffer was emitted
15331541
self.assertEqual(len(program.constant_buffer), 2)
1534-
self.assertEqual(len(program.constant_buffer[1].storage), 8)
1542+
1543+
# executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default
1544+
# converts the buffer from i64 to fp32 (4 bytes), which gets const propagated
1545+
self.assertEqual(len(program.constant_buffer[1].storage), 4)
15351546

15361547
def test_emit_lifted_tensor_constant(self) -> None:
15371548
class LiftedTensorConstants(nn.Module):
@@ -1544,7 +1555,7 @@ def forward(self, x):
15441555

15451556
model = LiftedTensorConstants()
15461557
# Specify that we want to move non-lifted constants to external file
1547-
et_cfg = ExecutorchBackendConfig(external_constants=True)
1558+
et_cfg = ExecutorchBackendConfig(external_constants=True, do_quant_fusion_and_const_prop=False)
15481559
program = to_edge(
15491560
export(model, (torch.ones(3, 2),), strict=True)
15501561
).to_executorch(et_cfg)
@@ -1566,7 +1577,7 @@ def forward(self, x):
15661577

15671578
model = LiftedConstants()
15681579
# Specify that we want to move non-lifted constants to external file
1569-
et_cfg = ExecutorchBackendConfig(external_constants=True)
1580+
et_cfg = ExecutorchBackendConfig(external_constants=True, do_quant_fusion_and_const_prop=False)
15701581
program = to_edge(
15711582
export(model, (torch.ones(3, 2),), strict=True)
15721583
).to_executorch(et_cfg)
@@ -1658,7 +1669,10 @@ def forward(self, x):
16581669
model = to_edge(export(InfinityMaskModel(), (torch.randn(2, 2),), strict=True))
16591670

16601671
# Confirm that we can serialize the model with infinity in it.
1661-
model = model.to_executorch()
1672+
config = ExecutorchBackendConfig(
1673+
do_quant_fusion_and_const_prop=False,
1674+
)
1675+
model = model.to_executorch(config)
16621676

16631677
# Assert that the infinity is stored as a string "-inf".
16641678
values = model.executorch_program.execution_plan[0].values
@@ -1716,8 +1730,8 @@ def forward(self, x):
17161730
external_map = emitter_output.external_constant_map[
17171731
"_default_external_constant"
17181732
]
1719-
self.assertEqual(external_map["linear.weight"], 0)
1720-
self.assertEqual(external_map["linear.bias"], 1)
1733+
self.assertEqual(external_map["_prop_tensor_constant0"], 1)
1734+
self.assertEqual(external_map["linear.bias"], 0)
17211735

17221736
def test_delegate_deduplicate(self) -> None:
17231737
class SharedModule(torch.nn.Module):
@@ -1804,7 +1818,7 @@ def forward(self, input, label):
18041818
ep = to_edge(ep)
18051819
# Lower the graph to executorch.
18061820
ep = ep.to_executorch(
1807-
config=ExecutorchBackendConfig(external_mutable_weights=True)
1821+
config=ExecutorchBackendConfig(external_mutable_weights=True, do_quant_fusion_and_const_prop=False)
18081822
)
18091823

18101824
emitter_output = ep._emitter_output

exir/passes/TARGETS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ python_library(
154154
"//caffe2:torch",
155155
"//executorch/exir:pass_base",
156156
"//executorch/exir/dialects:lib",
157+
"//pytorch/ao:torchao",
158+
"//executorch/exir/passes:constant_prop_pass",
157159
],
158160
)
159161

exir/passes/constant_prop_pass.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from collections import OrderedDict
1010
from typing import cast, Mapping, Optional
11+
import logging
1112

1213
import torch
1314
from executorch.exir.dialects._ops import ops as exir_ops
@@ -29,6 +30,31 @@
2930
# Propagating aten.full can significantly increase compiled model size.
3031
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}
3132

33+
# Do not const prop quantization primitives
34+
_QDQ_OPS = [
35+
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
36+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
37+
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
38+
exir_ops.edge.quantized_decomposed.convert_element_type.no_fuse,
39+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
40+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
41+
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
42+
exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
43+
]
44+
try:
45+
import torchao # noqa: F401
46+
_QDQ_OPS.extend(
47+
[
48+
exir_ops.edge.torchao.dequantize_affine.default,
49+
exir_ops.edge.torchao.quantize_affine.default,
50+
exir_ops.edge.torchao.choose_qparams_affine.default,
51+
]
52+
)
53+
except ImportError:
54+
pass
55+
_DEFAULT_SKIP_TARGETS.update(set(_QDQ_OPS))
56+
57+
3258
_PRIMITIVE_TYPES = (
3359
float,
3460
int,
@@ -40,7 +66,6 @@
4066
torch.layout,
4167
)
4268

43-
4469
def is_const(
4570
arg,
4671
exported_program: ExportedProgram,
@@ -308,7 +333,7 @@ def constant_prop_pass(
308333
if node.target == torch.ops.higher_order.cond
309334
]
310335
if len(has_control_flow) > 0:
311-
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
336+
logging.warning("constant_prop_pass does not constant propagate in control flow modules")
312337

313338
const_node_to_tensor = get_propagated_const_tensor_dict(
314339
exported_program, custom_skip_targets

exir/passes/quant_fusion_pass.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from torch.fx import GraphModule, subgraph_rewriter
1111
from torch.fx.passes.infra.pass_base import PassResult
1212
from torch.utils import _pytree as pytree
13+
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
14+
from torch.export import ExportedProgram
1315

1416
from ._quant_patterns_and_replacements import get_quant_patterns_and_replacements
1517

@@ -139,3 +141,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
139141
graph_module.graph.lint()
140142
graph_module.graph.eliminate_dead_code()
141143
return PassResult(graph_module, True)
144+
145+
146+
def quant_fusion_and_const_prop_pass(program: ExportedProgram) -> ExportedProgram:
147+
gm = program.graph_module
148+
gm_res = QuantFusionPass(_fix_node_meta_val=True)(gm)
149+
gm = gm_res.graph_module
150+
151+
# Do const prop pass to remove packing/dtype conversion ops
152+
program = constant_prop_pass(program)
153+
return program

exir/program/_program.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from executorch.exir.passes.normalize_view_copy_base_pass import (
5353
NormalizeViewCopyBasePass,
5454
)
55+
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
5556
from executorch.exir.passes.remove_graph_asserts_pass import (
5657
RemoveGraphAssertsPass,
5758
RemoveNonCoreAtenOpGraphAssertsPass,
@@ -1524,9 +1525,15 @@ def to_executorch(
15241525
after it has been transformed to the ExecuTorch backend.
15251526
"""
15261527
config = config if config else ExecutorchBackendConfig()
1527-
15281528
execution_programs: Dict[str, ExportedProgram] = {}
15291529
for name, program in self._edge_programs.items():
1530+
if config.do_quant_fusion_and_const_prop:
1531+
if program.graph_signature.backward_signature is not None:
1532+
raise Exception(
1533+
"Cannot run do_quant_fusion_and_const_prop on a graph with a backward signature intended for on-device training."
1534+
" Please set do_quant_fusion_and_const_prop to False in the ExecutorchBackendConfig."
1535+
)
1536+
program = quant_fusion_and_const_prop_pass(program)
15301537
program = weights_to_outputs_pass(program)
15311538
program = unsafe_remove_auto_functionalized_pass(program)
15321539
gm, new_signature = insert_write_back_for_buffers_pass(program)

exir/tests/test_memory_planning.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,8 @@ def forward(self, input, label):
769769
ep = export(net, inputs, strict=True)
770770
ep = _export_forward_backward(ep)
771771
ep = to_edge(ep)
772-
ep = ep.to_executorch()
772+
config = ExecutorchBackendConfig(do_quant_fusion_and_const_prop=False)
773+
ep = ep.to_executorch(config)
773774

774775
ep.dump_executorch_program(True)
775776

exir/tests/test_passes.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,7 +1085,16 @@ def forward(self) -> torch.Tensor:
10851085
self.assertEqual(ep.graph_signature.input_specs[1].arg.name, "b_a")
10861086

10871087
# Validate that the program successfully passes validation to executorch:
1088-
edge.to_executorch()
1088+
1089+
# The test fails when do_quant_fusion_and_const_prop=True, but it is not related to
1090+
# the pass, but rather that memory planning fails (AssertionError: graph_output_allocated not set)
1091+
# when a graph has no user inputs and no operations. We can construct a failure case
1092+
# even with do_quant_fusion_and_const_prop = False by changing the forward method in NoUserInputs
1093+
# to just return self.a
1094+
config = exir.ExecutorchBackendConfig(
1095+
do_quant_fusion_and_const_prop=False,
1096+
)
1097+
edge.to_executorch(config)
10891098

10901099
def test_constant_prop_pass_for_parameter(self) -> None:
10911100
def count_additions(gm: torch.fx.GraphModule) -> int:
@@ -1279,6 +1288,7 @@ class Module(torch.nn.Module):
12791288
def __init__(self):
12801289
super().__init__()
12811290
self.linear = torch.nn.Linear(3, 3)
1291+
self.w = torch.randn(3, 3)
12821292

12831293
def t(self, val):
12841294
return val + 1
@@ -1293,8 +1303,9 @@ def false_fn(self, val):
12931303
return self.linear(val) - self.f(val)
12941304

12951305
def forward(self, pred, x):
1306+
out = torch.nn.functional.linear(x, self.w.to(torch.float16).to(torch.float32))
12961307
return torch.ops.higher_order.cond(
1297-
pred, self.true_fn, self.false_fn, [x]
1308+
pred, self.true_fn, self.false_fn, [out]
12981309
)
12991310

13001311
mod = Module()
@@ -1304,14 +1315,41 @@ def forward(self, pred, x):
13041315
export(mod, (pred, x), strict=True),
13051316
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
13061317
)
1307-
error_msg = r"constant_prop_pass for control flow is not supported yet."
1308-
1309-
# TODO(chenlai): enable constant prop pass for control flow
1310-
with self.assertRaisesRegex(
1311-
RuntimeError,
1312-
error_msg,
1313-
):
1314-
_ = constant_prop_pass(edge.exported_program())
1318+
expected_out = edge.exported_program().module()(pred, x)
1319+
1320+
warn_log = "constant_prop_pass does not constant propagate in control flow modules"
1321+
with self.assertLogs(level="WARNING") as log:
1322+
program = constant_prop_pass(edge.exported_program())
1323+
self.assertIn(warn_log, log.output[0])
1324+
1325+
out = program.module()(pred, x)
1326+
self.assertTrue(torch.allclose(expected_out, out))
1327+
1328+
# dtype casts in parent module are const propagated
1329+
FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default(x, _prop_tensor_constant").run(program.graph_module.code)
1330+
1331+
def test_constant_prop_pass_quant_primitives(self) -> None:
1332+
class M(torch.nn.Module):
1333+
def __init__(self):
1334+
super().__init__()
1335+
self.w_int = torch.ones(3, 3, dtype=torch.int8)
1336+
self.w_scale = 3.0
1337+
self.w_zero_point = 3
1338+
1339+
def forward(self, x):
1340+
w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
1341+
self.w_int, self.w_scale, self.w_zero_point, -127, 128, torch.int8)
1342+
return torch.nn.functional.linear(x, w_dq)
1343+
1344+
mod = M()
1345+
x = torch.randn([3])
1346+
mod(x)
1347+
edge = to_edge(
1348+
export(mod, (x,), strict=True),
1349+
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
1350+
)
1351+
constant_prop_pass(edge.exported_program())
1352+
FileCheck().check("executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default").run(edge.exported_program().graph_module.code)
13151353

13161354
def test_mutable_buffers(self) -> None:
13171355
def count_copies(gm: torch.fx.GraphModule) -> int:

0 commit comments

Comments
 (0)