
Commit 7394f7b

metascroy authored and facebook-github-bot committed
Enable quant fusion and const propagation by default (pytorch#10394)
Summary: This diff enables quant fusion and constant propagation by default in ExecuTorch. These passes run after all to_edge passes but before memory planning. Differential Revision: D73513516
1 parent 2553d99 commit 7394f7b

File tree: 8 files changed, +114 −26 lines


exir/capture/_config.py

Lines changed: 3 additions & 0 deletions
@@ -102,3 +102,6 @@ class ExecutorchBackendConfig:
     # serialized in the PTE file. Its value is ignored if mutable buffers are not
     # memory planned as the names must be serialized in that case.
     emit_mutable_buffer_names: bool = False
+
+    # If set to true, we run quant fusion and constant propagation passes
+    do_quant_fusion_and_const_prop: bool = True
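
For reference, a minimal sketch of toggling the new option (assuming ExecutorchBackendConfig is imported from the module path shown in this diff; the flag is consumed by to_executorch() in exir/program/_program.py below):

from executorch.exir.capture._config import ExecutorchBackendConfig

# Defaults to True: quant fusion and constant propagation run automatically
# during to_executorch(). Set it to False to opt out (pre-change behavior).
config = ExecutorchBackendConfig(do_quant_fusion_and_const_prop=False)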

exir/passes/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -154,6 +154,8 @@ python_library(
        "//caffe2:torch",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
+        "//pytorch/ao:torchao",
+        "//executorch/exir/passes:constant_prop_pass",
    ],
)

exir/passes/constant_prop_pass.py

Lines changed: 27 additions & 2 deletions
@@ -8,6 +8,7 @@
 
 from collections import OrderedDict
 from typing import cast, Mapping, Optional
+import logging
 
 import torch
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -29,6 +30,31 @@
 # Propagating aten.full can significantly increase compiled model size.
 _DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}
 
+# Do not const prop quantization primitives
+_QDQ_OPS = [
+    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
+    exir_ops.edge.quantized_decomposed.convert_element_type.no_fuse,
+    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+    exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
+    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
+    exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
+]
+try:
+    import torchao  # noqa: F401
+
+    _QDQ_OPS.extend(
+        [
+            exir_ops.edge.torchao.dequantize_affine.default,
+            exir_ops.edge.torchao.quantize_affine.default,
+            exir_ops.edge.torchao.choose_qparams_affine.default,
+        ]
+    )
+except ImportError:
+    pass
+_DEFAULT_SKIP_TARGETS.update(set(_QDQ_OPS))
+
+
 _PRIMITIVE_TYPES = (
     float,
     int,
@@ -40,7 +66,6 @@
     torch.layout,
 )
 
-
 def is_const(
     arg,
     exported_program: ExportedProgram,
@@ -308,7 +333,7 @@ def constant_prop_pass(
         if node.target == torch.ops.higher_order.cond
     ]
     if len(has_control_flow) > 0:
-        raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
+        logging.warning("constant_prop_pass does not constant propagate in control flow modules")
 
     const_node_to_tensor = get_propagated_const_tensor_dict(
         exported_program, custom_skip_targets
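
A quick illustration of the new default skip behavior (a sketch; `edge` stands for an EdgeProgramManager produced by to_edge, as in the updated tests below):

from executorch.exir.passes.constant_prop_pass import constant_prop_pass

# Quant/dequant primitives are now part of _DEFAULT_SKIP_TARGETS, so ops like
# dequantize_per_tensor survive the pass instead of being folded into fp32
# constants, while other constant expressions (e.g. dtype casts) are still
# propagated. Graphs containing torch.ops.higher_order.cond now log a warning
# instead of raising.
program = constant_prop_pass(edge.exported_program())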

exir/passes/quant_fusion_pass.py

Lines changed: 13 additions & 0 deletions
@@ -10,6 +10,8 @@
 from torch.fx import GraphModule, subgraph_rewriter
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.utils import _pytree as pytree
+from executorch.exir.passes.constant_prop_pass import constant_prop_pass
+from torch.export import ExportedProgram
 
 from ._quant_patterns_and_replacements import get_quant_patterns_and_replacements
 
@@ -139,3 +141,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
         graph_module.graph.lint()
         graph_module.graph.eliminate_dead_code()
         return PassResult(graph_module, True)
+
+
+def quant_fusion_and_const_prop_pass(program: ExportedProgram) -> ExportedProgram:
+    gm = program.graph_module
+    gm_res = QuantFusionPass(_fix_node_meta_val=True)(gm)
+    gm = gm_res.graph_module
+    program.validate()
+
+    # Do const prop pass to remove packing/dtype conversion ops
+    program = constant_prop_pass(program)
+    return program
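
A sketch of calling the new helper directly on an edge program. The toy module mirrors the dequantize-then-linear pattern used in the new test in exir/tests/test_passes.py below; it assumes the quantized_decomposed ops are registered by the ExecuTorch imports, as in those tests:

import torch
from executorch import exir
from executorch.exir import to_edge
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
from torch.export import export


class QuantizedLinear(torch.nn.Module):
    # Same dequantize-then-linear pattern as the new test below.
    def __init__(self):
        super().__init__()
        self.w_int = torch.ones(3, 3, dtype=torch.int8)
        self.w_scale = 3.0
        self.w_zero_point = 3

    def forward(self, x):
        w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
            self.w_int, self.w_scale, self.w_zero_point, -127, 128, torch.int8
        )
        return torch.nn.functional.linear(x, w_dq)


edge = to_edge(
    export(QuantizedLinear(), (torch.randn(3),), strict=True),
    compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
# Runs QuantFusionPass, validates the program, then const-props leftover
# packing / dtype-conversion ops; returns the rewritten ExportedProgram.
program = quant_fusion_and_const_prop_pass(edge.exported_program())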

exir/program/_program.py

Lines changed: 5 additions & 1 deletion
@@ -52,6 +52,7 @@
 from executorch.exir.passes.normalize_view_copy_base_pass import (
     NormalizeViewCopyBasePass,
 )
+from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
 from executorch.exir.passes.remove_graph_asserts_pass import (
     RemoveGraphAssertsPass,
     RemoveNonCoreAtenOpGraphAssertsPass,
@@ -1526,9 +1527,12 @@ def to_executorch(
             after it has been transformed to the ExecuTorch backend.
         """
         config = config if config else ExecutorchBackendConfig()
-
         execution_programs: Dict[str, ExportedProgram] = {}
         for name, program in self._edge_programs.items():
+            # Do constant propagation. This is needed for some quant fusion
+            # passes to work correctly
+            if config.do_quant_fusion_and_const_prop:
+                program = quant_fusion_and_const_prop_pass(program)
             program = weights_to_outputs_pass(program)
             program = unsafe_remove_auto_functionalized_pass(program)
             gm, new_signature = insert_write_back_for_buffers_pass(program)
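
The net effect for callers, sketched under the assumption that ExecutorchBackendConfig is re-exported from executorch.exir (otherwise import it from executorch.exir.capture._config as shown above) and continuing from an `edge` program like the one in the previous sketch:

from executorch.exir import ExecutorchBackendConfig

# Default path: quant fusion and constant propagation now run on every edge
# program inside to_executorch(), after all to_edge passes and before memory
# planning.
exec_prog = edge.to_executorch()

# Opt-out path (restores the pre-change behavior):
# exec_prog = edge.to_executorch(
#     ExecutorchBackendConfig(do_quant_fusion_and_const_prop=False)
# )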

exir/tests/test_passes.py

Lines changed: 38 additions & 9 deletions
@@ -1279,6 +1279,7 @@ class Module(torch.nn.Module):
             def __init__(self):
                 super().__init__()
                 self.linear = torch.nn.Linear(3, 3)
+                self.w = torch.randn(3, 3)
 
             def t(self, val):
                 return val + 1
@@ -1293,8 +1294,9 @@ def false_fn(self, val):
                 return self.linear(val) - self.f(val)
 
             def forward(self, pred, x):
+                out = torch.nn.functional.linear(x, self.w.to(torch.float16).to(torch.float32))
                 return torch.ops.higher_order.cond(
-                    pred, self.true_fn, self.false_fn, [x]
+                    pred, self.true_fn, self.false_fn, [out]
                 )
 
         mod = Module()
@@ -1304,14 +1306,41 @@ def forward(self, pred, x):
            export(mod, (pred, x), strict=True),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
-        error_msg = r"constant_prop_pass for control flow is not supported yet."
-
-        # TODO(chenlai): enable constant prop pass for control flow
-        with self.assertRaisesRegex(
-            RuntimeError,
-            error_msg,
-        ):
-            _ = constant_prop_pass(edge.exported_program())
+        expected_out = edge.exported_program().module()(pred, x)
+
+        warn_log = "constant_prop_pass does not constant propagate in control flow modules"
+        with self.assertLogs(level="WARNING") as log:
+            program = constant_prop_pass(edge.exported_program())
+            self.assertIn(warn_log, log.output[0])
+
+        out = program.module()(pred, x)
+        self.assertTrue(torch.allclose(expected_out, out))
+
+        # dtype casts in parent module are const propagated
+        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default(x, _prop_tensor_constant").run(program.graph_module.code)
+
+    def test_constant_prop_pass_quant_primitives(self) -> None:
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.w_int = torch.ones(3, 3, dtype=torch.int8)
+                self.w_scale = 3.0
+                self.w_zero_point = 3
+
+            def forward(self, x):
+                w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+                    self.w_int, self.w_scale, self.w_zero_point, -127, 128, torch.int8)
+                return torch.nn.functional.linear(x, w_dq)
+
+        mod = M()
+        x = torch.randn([3])
+        mod(x)
+        edge = to_edge(
+            export(mod, (x,), strict=True),
+            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
+        )
+        constant_prop_pass(edge.exported_program())
+        FileCheck().check("executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default").run(edge.exported_program().graph_module.code)
 
     def test_mutable_buffers(self) -> None:
         def count_copies(gm: torch.fx.GraphModule) -> int:

exir/tests/test_quant_fusion_pass.py

Lines changed: 24 additions & 12 deletions
@@ -12,7 +12,7 @@
 from executorch import exir
 from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
-from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
+from executorch.exir.passes.quant_fusion_pass import QuantFusionPass, quant_fusion_and_const_prop_pass
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from torch.ao.quantization import (  # @manual
     float_qparams_weight_only_qconfig,
@@ -33,7 +33,7 @@
 from torch.testing import FileCheck
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
-
+import copy
 
 class TestQuantFusionPass(unittest.TestCase):
     @classmethod
@@ -419,6 +419,7 @@ def _test_embedding_torchao(
         m = to_edge(
             export(model, example_inputs, strict=True), compile_config=compile_config
         )
+        m_copy = copy.deepcopy(m)
 
         # Before pass, we see torchao dequantize and embedding ops
         FileCheck().check_count(
@@ -437,13 +438,9 @@
 
         # After pass, we see packing op and quantized embedding op, but no torchao dequantize op
         FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
-            1 if bit_width < 8 else 0,
-            exactly=True,
+            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default", 1 if bit_width < 8 else 0, exactly=True
        ).check_count(
-            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
-            1,
-            exactly=True,
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
        ).check_not(
            "executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default"
        ).run(
@@ -454,9 +451,7 @@
 
         # After constant prop, we see quantized embedding op, but no packing op
         FileCheck().check_count(
-            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
-            1,
-            exactly=True,
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
        ).check_not(
            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
        ).run(
@@ -468,4 +463,21 @@
         self.assertTrue(torch.allclose(expected_outputs, actual_outputs))
 
         # Can lower to executorch
-        exec_prog = m.to_executorch()  # noqa: F841
+        exec_prog = m.to_executorch()  # noqa
+
+
+        # Alternative flow 2 using quant_fusion_pass on exported program
+        quant_fusion_and_const_prop_pass(m_copy.exported_program())
+        FileCheck().check_count(
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
+        ).check_not(
+            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
+        ).run(
+            m_copy.exported_program().graph_module.code
+        )
+
+        actual_outputs2 = m_copy.exported_program().module()(*example_inputs)
+        self.assertTrue(torch.allclose(expected_outputs, actual_outputs2))
+
+        # Can lower to executorch
+        exec_prog2 = m_copy.to_executorch()  # noqa

extension/llm/export/builder.py

Lines changed: 2 additions & 2 deletions
@@ -508,8 +508,8 @@ def to_executorch(
                # If there are Linear operations left in the graph, let's execute
                # them with the optimized op_linear rather than materializing a
                # transpose followed by a regular op_mm.
-                ConvertToLinearPass(),
-                QuantFusionPass(),
+                # ConvertToLinearPass(),
+                # QuantFusionPass(),
            ]
            if passes:
                # pyre-fixme[6]: In call `list.extend`, for 1st positional argument,
