Commit 9c8ae82

metascroy authored and facebook-github-bot committed
Enable quant fusion and const propagation by default (pytorch#10394)
Summary: This diff enables quant fusion and constant propagation by default in ExecuTorch. The passes run after all to_edge passes, but before memory planning.

Differential Revision: D73513516
1 parent fa89efa commit 9c8ae82

File tree

10 files changed (+154, -48 lines)


exir/capture/_config.py

Lines changed: 3 additions & 0 deletions
@@ -102,3 +102,6 @@ class ExecutorchBackendConfig:
     # serialized in the PTE file. Its value is ignored if mutable buffers are not
     # memory planned as the names must be serialized in that case.
     emit_mutable_buffer_names: bool = False
+
+    # If set to true, we run quant fusion and constant propagation passes
+    do_quant_fusion_and_const_prop: bool = True
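For readers skimming the diff: the flag above is consumed through the normal to_executorch() config path. Below is a minimal opt-out sketch mirroring the test updates further down; the toy module and exact import paths are illustrative and may differ by ExecuTorch version.

import torch
from torch.export import export
from executorch.exir import to_edge, ExecutorchBackendConfig

class TinyModel(torch.nn.Module):  # hypothetical module, for illustration only
    def forward(self, x):
        return torch.nn.functional.relu(x)

edge = to_edge(export(TinyModel(), (torch.ones(3, 2),), strict=True))

# With this commit, to_executorch() runs quant fusion + constant propagation by default.
# Passing a config with do_quant_fusion_and_const_prop=False restores the old behavior.
program = edge.to_executorch(
    ExecutorchBackendConfig(do_quant_fusion_and_const_prop=False)
)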

exir/emit/test/test_emit.py

Lines changed: 36 additions & 22 deletions
@@ -431,8 +431,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             .executorch_program
         )
         # The value for beta should appear before alpha
-        self.assertEqual(program.execution_plan[0].values[12].val, Int(3))
-        self.assertEqual(program.execution_plan[0].values[13].val, Int(2))
+        self.assertEqual(program.execution_plan[0].values[4].val, Int(3))
+        self.assertEqual(program.execution_plan[0].values[5].val, Int(2))

     def test_kwargs2(self) -> None:
         """Tests that the kwargs are placed in the order specified by

@@ -451,10 +451,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             to_edge(export(f, (x,), strict=True)).to_executorch().executorch_program
         )
         # The value for right should appear before side
-        self.assertEqual(program.execution_plan[0].values[6].val, Bool(False))
-        self.assertEqual(program.execution_plan[0].values[7].val, Bool(True))
-        self.assertEqual(program.execution_plan[0].values[8].val, String("right"))
-        self.assertEqual(program.execution_plan[0].values[9].val, Null())
+        self.assertEqual(program.execution_plan[0].values[3].val, Bool(False))
+        self.assertEqual(program.execution_plan[0].values[4].val, Bool(True))
+        self.assertEqual(program.execution_plan[0].values[5].val, String("right"))
+        self.assertEqual(program.execution_plan[0].values[6].val, Null())

     def _assertCallLength(self, program: Program, idx: int, expected_len: int) -> None:
         instr_args = program.execution_plan[0].chains[0].instructions[idx].instr_args

@@ -532,24 +532,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Check the mul operator's stack trace contains f -> g -> h
         self.assertTrue(
             "return torch.mul(x, torch.randn(3, 2))"
-            in program.execution_plan[0].chains[0].stacktrace[1].items[-1].context
+            in program.execution_plan[0].chains[0].stacktrace[0].items[-1].context
         )
         self.assertEqual(
-            program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "f"
+            program.execution_plan[0].chains[0].stacktrace[0].items[-1].name, "f"
         )
         self.assertEqual(
-            program.execution_plan[0].chains[0].stacktrace[1].items[-2].name, "g"
+            program.execution_plan[0].chains[0].stacktrace[0].items[-2].name, "g"
         )
         self.assertEqual(
-            program.execution_plan[0].chains[0].stacktrace[1].items[-3].name, "forward"
+            program.execution_plan[0].chains[0].stacktrace[0].items[-3].name, "forward"
         )

         # Check the sin operator's stack trace contains g -> h
         self.assertEqual(
-            program.execution_plan[0].chains[0].stacktrace[2].items[-1].name, "g"
+            program.execution_plan[0].chains[0].stacktrace[1].items[-1].name, "g"
         )
         self.assertEqual(
-            program.execution_plan[0].chains[0].stacktrace[2].items[-2].name, "forward"
+            program.execution_plan[0].chains[0].stacktrace[1].items[-2].name, "forward"
         )

     def test_stacktrace_off(self) -> None:

@@ -878,10 +878,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             .executorch_program.execution_plan[0]
             .non_const_buffer_sizes
         )
-
+
+        config = ExecutorchBackendConfig(
+            do_quant_fusion_and_const_prop=False,
+        )
         edge_program_manager = to_edge(export(f, (torch.ones(3, 2),), strict=True))
         non_const_buffer_size_without_const_prop_pass = (
-            edge_program_manager.to_executorch()
+            edge_program_manager.to_executorch(config)
             .executorch_program.execution_plan[0]
             .non_const_buffer_sizes
         )

@@ -1510,7 +1513,12 @@ def forward(self, x):
         self.assertEqual(model.W1.untyped_storage().nbytes(), 8)
         self.assertEqual(model.W2.nbytes, 4)
         self.assertEqual(model.W2.untyped_storage().nbytes(), 8)
-        program = to_edge(export(model, (torch.ones(1),), strict=True)).to_executorch()
+
+        # Without this, the views get
+        config = exir.ExecutorchBackendConfig(
+            do_quant_fusion_and_const_prop=False,
+        )
+        program = to_edge(export(model, (torch.ones(1),), strict=True)).to_executorch(config)

         program = program._emitter_output.program
         # each emitted weight is not a view

@@ -1531,7 +1539,10 @@ def forward(self, x):
         program = program._emitter_output.program
         # confirm that the buffer was emitted
         self.assertEqual(len(program.constant_buffer), 2)
-        self.assertEqual(len(program.constant_buffer[1].storage), 8)
+
+        # executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default
+        # converts the buffer from i64 to fp32 (4 bytes), which gets const propagated
+        self.assertEqual(len(program.constant_buffer[1].storage), 4)

     def test_emit_lifted_tensor_constant(self) -> None:
         class LiftedTensorConstants(nn.Module):

@@ -1544,7 +1555,7 @@ def forward(self, x):

         model = LiftedTensorConstants()
         # Specify that we want to move non-lifted constants to external file
-        et_cfg = ExecutorchBackendConfig(external_constants=True)
+        et_cfg = ExecutorchBackendConfig(external_constants=True, do_quant_fusion_and_const_prop=False)
         program = to_edge(
             export(model, (torch.ones(3, 2),), strict=True)
         ).to_executorch(et_cfg)

@@ -1566,7 +1577,7 @@ def forward(self, x):

         model = LiftedConstants()
         # Specify that we want to move non-lifted constants to external file
-        et_cfg = ExecutorchBackendConfig(external_constants=True)
+        et_cfg = ExecutorchBackendConfig(external_constants=True, do_quant_fusion_and_const_prop=False)
         program = to_edge(
             export(model, (torch.ones(3, 2),), strict=True)
         ).to_executorch(et_cfg)

@@ -1658,7 +1669,10 @@ def forward(self, x):
         model = to_edge(export(InfinityMaskModel(), (torch.randn(2, 2),), strict=True))

         # Confirm that we can serialize the model with infinity in it.
-        model = model.to_executorch()
+        config = ExecutorchBackendConfig(
+            do_quant_fusion_and_const_prop=False,
+        )
+        model = model.to_executorch(config)

         # Assert that the infinity is stored as a string "-inf".
         values = model.executorch_program.execution_plan[0].values

@@ -1716,8 +1730,8 @@ def forward(self, x):
         external_map = emitter_output.external_constant_map[
             "_default_external_constant"
         ]
-        self.assertEqual(external_map["linear.weight"], 0)
-        self.assertEqual(external_map["linear.bias"], 1)
+        self.assertEqual(external_map["_prop_tensor_constant0"], 1)
+        self.assertEqual(external_map["linear.bias"], 0)

     def test_delegate_deduplicate(self) -> None:
         class SharedModule(torch.nn.Module):

@@ -1804,7 +1818,7 @@ def forward(self, input, label):
         ep = to_edge(ep)
         # Lower the graph to executorch.
         ep = ep.to_executorch(
-            config=ExecutorchBackendConfig(external_mutable_weights=True)
+            config=ExecutorchBackendConfig(external_mutable_weights=True, do_quant_fusion_and_const_prop=False)
         )

         emitter_output = ep._emitter_output

exir/passes/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -154,6 +154,8 @@ python_library(
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",
+        "//pytorch/ao:torchao",
+        "//executorch/exir/passes:constant_prop_pass",
     ],
 )

exir/passes/constant_prop_pass.py

Lines changed: 27 additions & 2 deletions
@@ -8,6 +8,7 @@

 from collections import OrderedDict
 from typing import cast, Mapping, Optional
+import logging

 import torch
 from executorch.exir.dialects._ops import ops as exir_ops

@@ -29,6 +30,31 @@
 # Propagating aten.full can significantly increase compiled model size.
 _DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}

+# Do not const prop quantization primitives
+_QDQ_OPS = [
+    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
+    exir_ops.edge.quantized_decomposed.convert_element_type.no_fuse,
+    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
+    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
+    exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
+]
+try:
+    import torchao  # noqa: F401
+    _QDQ_OPS.extend(
+        [
+            exir_ops.edge.torchao.dequantize_affine.default,
+            exir_ops.edge.torchao.quantize_affine.default,
+            exir_ops.edge.torchao.choose_qparams_affine.default,
+        ]
+    )
+except ImportError:
+    pass
+_DEFAULT_SKIP_TARGETS.update(set(_QDQ_OPS))
+
+
 _PRIMITIVE_TYPES = (
     float,

@@ -40,7 +66,6 @@
     torch.layout,
 )

-
 def is_const(
     arg,
     exported_program: ExportedProgram,

@@ -308,7 +333,7 @@ def constant_prop_pass(
         if node.target == torch.ops.higher_order.cond
     ]
     if len(has_control_flow) > 0:
-        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
+        logging.warning("constant_prop_pass does not constant propagate in control flow modules")

     const_node_to_tensor = get_propagated_const_tensor_dict(
         exported_program, custom_skip_targets
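To see the behavioral change in isolation: constant_prop_pass can be run directly on an exported edge program, and after this diff it only warns on torch.ops.higher_order.cond instead of raising, while quant/dequant primitives stay in the graph because they are in the default skip set. A rough sketch follows; the module and inputs are made up for illustration and import paths may vary.

import torch
from torch.export import export
from executorch.exir import to_edge
from executorch.exir.passes.constant_prop_pass import constant_prop_pass

class FoldableCast(torch.nn.Module):  # illustrative module with a foldable dtype round-trip
    def __init__(self):
        super().__init__()
        self.w = torch.randn(3, 3)

    def forward(self, x):
        # The fp16 round-trip on a constant weight is a candidate for const propagation.
        return torch.nn.functional.linear(x, self.w.to(torch.float16).to(torch.float32))

edge = to_edge(export(FoldableCast(), (torch.randn(3),), strict=True))
program = constant_prop_pass(edge.exported_program())
# The cast chain should now be folded into a _prop_tensor_constant* placeholder.
print(program.graph_module.code)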

exir/passes/quant_fusion_pass.py

Lines changed: 13 additions & 0 deletions
@@ -10,6 +10,8 @@
 from torch.fx import GraphModule, subgraph_rewriter
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.utils import _pytree as pytree
+from executorch.exir.passes.constant_prop_pass import constant_prop_pass
+from torch.export import ExportedProgram

 from ._quant_patterns_and_replacements import get_quant_patterns_and_replacements

@@ -139,3 +141,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
         graph_module.graph.lint()
         graph_module.graph.eliminate_dead_code()
         return PassResult(graph_module, True)
+
+
+def quant_fusion_and_const_prop_pass(program: ExportedProgram) -> ExportedProgram:
+    gm = program.graph_module
+    gm_res = QuantFusionPass(_fix_node_meta_val=True)(gm)
+    gm = gm_res.graph_module
+    program.validate()
+
+    # Do const prop pass to remove packing/dtype conversion ops
+    program = constant_prop_pass(program)
+    return program
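The new helper simply chains QuantFusionPass and constant_prop_pass, and it is what to_executorch() now calls when the config flag is set. A self-contained sketch of invoking it manually; the toy float model is illustrative only (a quantized graph would show real q/dq fusion):

import torch
from torch.export import export
from executorch.exir import to_edge
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass

class Linear(torch.nn.Module):  # toy model, assumed for illustration
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)

edge = to_edge(export(Linear(), (torch.randn(2, 4),), strict=True))
# Applies the quant fusion patterns first, then folds remaining packing /
# dtype-conversion chains via constant propagation; returns the updated program.
program = quant_fusion_and_const_prop_pass(edge.exported_program())
print(program.graph_module.code)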

exir/program/_program.py

Lines changed: 8 additions & 1 deletion
@@ -52,6 +52,7 @@
 from executorch.exir.passes.normalize_view_copy_base_pass import (
     NormalizeViewCopyBasePass,
 )
+from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
 from executorch.exir.passes.remove_graph_asserts_pass import (
     RemoveGraphAssertsPass,
     RemoveNonCoreAtenOpGraphAssertsPass,

@@ -1524,9 +1525,15 @@ def to_executorch(
         after it has been transformed to the ExecuTorch backend.
         """
         config = config if config else ExecutorchBackendConfig()
-
         execution_programs: Dict[str, ExportedProgram] = {}
         for name, program in self._edge_programs.items():
+            if config.do_quant_fusion_and_const_prop:
+                if program.graph_signature.backward_signature is not None:
+                    raise Exception(
+                        "Cannot run do_quant_fusion_and_const_prop on a graph with a backward signature intended for on-device training."
+                        " Please set do_quant_fusion_and_const_prop to False in the ExecutorchBackendConfig."
+                    )
+                program = quant_fusion_and_const_prop_pass(program)
             program = weights_to_outputs_pass(program)
             program = unsafe_remove_auto_functionalized_pass(program)
             gm, new_signature = insert_write_back_for_buffers_pass(program)

exir/tests/test_passes.py

Lines changed: 38 additions & 9 deletions
@@ -1279,6 +1279,7 @@ class Module(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)
+                self.w = torch.randn(3, 3)

             def t(self, val):
                 return val + 1

@@ -1293,8 +1294,9 @@ def false_fn(self, val):
                 return self.linear(val) - self.f(val)

             def forward(self, pred, x):
+                out = torch.nn.functional.linear(x, self.w.to(torch.float16).to(torch.float32))
                 return torch.ops.higher_order.cond(
-                    pred, self.true_fn, self.false_fn, [x]
+                    pred, self.true_fn, self.false_fn, [out]
                 )

         mod = Module()

@@ -1304,14 +1306,41 @@ def forward(self, pred, x):
             export(mod, (pred, x), strict=True),
             compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
         )
-        error_msg = r"constant_prop_pass for control flow is not supported yet."
-
-        # TODO(chenlai): enable constant prop pass for control flow
-        with self.assertRaisesRegex(
-            RuntimeError,
-            error_msg,
-        ):
-            _ = constant_prop_pass(edge.exported_program())
+        expected_out = edge.exported_program().module()(pred, x)
+
+        warn_log = "constant_prop_pass does not constant propagate in control flow modules"
+        with self.assertLogs(level="WARNING") as log:
+            program = constant_prop_pass(edge.exported_program())
+            self.assertIn(warn_log, log.output[0])
+
+        out = program.module()(pred, x)
+        self.assertTrue(torch.allclose(expected_out, out))
+
+        # dtype casts in parent module are const propagated
+        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default(x, _prop_tensor_constant").run(program.graph_module.code)
+
+    def test_constant_prop_pass_quant_primitives(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w_int = torch.ones(3, 3, dtype=torch.int8)
+                self.w_scale = 3.0
+                self.w_zero_point = 3
+
+            def forward(self, x):
+                w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+                    self.w_int, self.w_scale, self.w_zero_point, -127, 128, torch.int8)
+                return torch.nn.functional.linear(x, w_dq)
+
+        mod = M()
+        x = torch.randn([3])
+        mod(x)
+        edge = to_edge(
+            export(mod, (x,), strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        constant_prop_pass(edge.exported_program())
+        FileCheck().check("executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default").run(edge.exported_program().graph_module.code)

     def test_mutable_buffers(self) -> None:
         def count_copies(gm: torch.fx.GraphModule) -> int:
