3 changes: 3 additions & 0 deletions exir/capture/_config.py
@@ -102,3 +102,6 @@ class ExecutorchBackendConfig:
# serialized in the PTE file. Its value is ignored if mutable buffers are not
# memory planned as the names must be serialized in that case.
emit_mutable_buffer_names: bool = False

# If set to true, we run quant fusion and constant propagation passes
do_quant_fusion_and_const_prop: bool = False
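
For context, a minimal usage sketch (not part of this diff) of how the new flag would be enabled when lowering; the tiny model is illustrative and the assumed entry points are the standard to_edge / to_executorch APIs:

    import torch
    from executorch.exir import ExecutorchBackendConfig, to_edge
    from torch.export import export

    class TinyModel(torch.nn.Module):  # placeholder model, for illustration only
        def forward(self, x):
            return x + 1

    edge = to_edge(export(TinyModel(), (torch.randn(3),), strict=True))
    # With the flag set, quant fusion and constant propagation run inside to_executorch().
    et_program = edge.to_executorch(
        ExecutorchBackendConfig(do_quant_fusion_and_const_prop=True)
    )
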
2 changes: 2 additions & 0 deletions exir/passes/TARGETS
@@ -154,6 +154,8 @@ python_library(
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
"//pytorch/ao:torchao",
"//executorch/exir/passes:constant_prop_pass",
],
)

29 changes: 27 additions & 2 deletions exir/passes/constant_prop_pass.py
@@ -8,6 +8,7 @@

from collections import OrderedDict
from typing import cast, Mapping, Optional
import logging

import torch
from executorch.exir.dialects._ops import ops as exir_ops
@@ -29,6 +30,31 @@
# Propagating aten.full can significantly increase compiled model size.
_DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}

# Do not const prop quantization primitives
_QDQ_OPS = [
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.convert_element_type.no_fuse,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
]
try:
import torchao # noqa: F401
_QDQ_OPS.extend(
[
exir_ops.edge.torchao.dequantize_affine.default,
exir_ops.edge.torchao.quantize_affine.default,
exir_ops.edge.torchao.choose_qparams_affine.default,
]
)
except ImportError:
pass
_DEFAULT_SKIP_TARGETS.update(set(_QDQ_OPS))


_PRIMITIVE_TYPES = (
float,
int,
@@ -40,7 +66,6 @@
torch.layout,
)


def is_const(
arg,
exported_program: ExportedProgram,
@@ -308,7 +333,7 @@ def constant_prop_pass(
if node.target == torch.ops.higher_order.cond
]
if len(has_control_flow) > 0:
raise RuntimeError("constant_prop_pass for control flow is not supported yet.")
logging.warning("constant_prop_pass does not constant propagate in control flow modules")

const_node_to_tensor = get_propagated_const_tensor_dict(
exported_program, custom_skip_targets
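The practical effect of the extended skip set is that quantize/dequantize nodes fed by constant weights are left in the graph for QuantFusionPass to consume, rather than being folded into full-precision tensors. A condensed sketch of that behavior, mirroring the new unit test added to test_passes.py below (module name and quantization bounds here are illustrative):

    import torch
    import torch.ao.quantization.fx._decomposed  # noqa: F401  # registers quantized_decomposed ops, if not already registered
    from executorch import exir
    from executorch.exir import to_edge
    from executorch.exir.passes.constant_prop_pass import constant_prop_pass
    from torch.export import export

    class WeightOnly(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.w_int = torch.ones(3, 3, dtype=torch.int8)

        def forward(self, x):
            w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                self.w_int, 3.0, 3, -128, 127, torch.int8
            )
            return torch.nn.functional.linear(x, w_dq)

    edge = to_edge(
        export(WeightOnly(), (torch.randn(3),), strict=True),
        compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
    )
    program = constant_prop_pass(edge.exported_program())
    # dequantize_per_tensor is in the default skip set, so the node is still present
    # here and remains available for QuantFusionPass to fuse.
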
12 changes: 12 additions & 0 deletions exir/passes/quant_fusion_pass.py
@@ -10,6 +10,8 @@
from torch.fx import GraphModule, subgraph_rewriter
from torch.fx.passes.infra.pass_base import PassResult
from torch.utils import _pytree as pytree
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from torch.export import ExportedProgram

from ._quant_patterns_and_replacements import get_quant_patterns_and_replacements

@@ -139,3 +141,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
graph_module.graph.lint()
graph_module.graph.eliminate_dead_code()
return PassResult(graph_module, True)


def quant_fusion_and_const_prop_pass(program: ExportedProgram) -> ExportedProgram:
gm = program.graph_module
gm_res = QuantFusionPass(_fix_node_meta_val=True)(gm)
gm = gm_res.graph_module

# Do const prop pass to remove packing/dtype conversion ops
program = constant_prop_pass(program)
return program
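
The helper can also be invoked directly on an exported program, outside of to_executorch(), as the updated torchao embedding test does below. A minimal sketch, assuming `edge` is an EdgeProgramManager that already holds a quantized model:

    from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass

    # Assumption: `edge` was produced by to_edge(export(quantized_model, example_inputs)).
    program = quant_fusion_and_const_prop_pass(edge.exported_program())
    # `program` now contains fused quantized ops, with packing/dtype-conversion ops
    # folded away by constant propagation.
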
9 changes: 8 additions & 1 deletion exir/program/_program.py
@@ -52,6 +52,7 @@
from executorch.exir.passes.normalize_view_copy_base_pass import (
NormalizeViewCopyBasePass,
)
from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass
from executorch.exir.passes.remove_graph_asserts_pass import (
RemoveGraphAssertsPass,
RemoveNonCoreAtenOpGraphAssertsPass,
@@ -1524,9 +1525,15 @@ def to_executorch(
after it has been transformed to the ExecuTorch backend.
"""
config = config if config else ExecutorchBackendConfig()

execution_programs: Dict[str, ExportedProgram] = {}
for name, program in self._edge_programs.items():
if config.do_quant_fusion_and_const_prop:
if program.graph_signature.backward_signature is not None:
raise Exception(
"Cannot run do_quant_fusion_and_const_prop on a graph with a backward signature intended for on-device training."
" Please set do_quant_fusion_and_const_prop to False in the ExecutorchBackendConfig."
)
program = quant_fusion_and_const_prop_pass(program)
program = weights_to_outputs_pass(program)
program = unsafe_remove_auto_functionalized_pass(program)
gm, new_signature = insert_write_back_for_buffers_pass(program)
47 changes: 38 additions & 9 deletions exir/tests/test_passes.py
@@ -1279,6 +1279,7 @@ class Module(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3, 3)
self.w = torch.randn(3, 3)

def t(self, val):
return val + 1
@@ -1293,8 +1294,9 @@ def false_fn(self, val):
return self.linear(val) - self.f(val)

def forward(self, pred, x):
out = torch.nn.functional.linear(x, self.w.to(torch.float16).to(torch.float32))
return torch.ops.higher_order.cond(
pred, self.true_fn, self.false_fn, [x]
pred, self.true_fn, self.false_fn, [out]
)

mod = Module()
@@ -1304,14 +1306,41 @@ def forward(self, pred, x):
export(mod, (pred, x), strict=True),
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
error_msg = r"constant_prop_pass for control flow is not supported yet."

# TODO(chenlai): enable constant prop pass for control flow
with self.assertRaisesRegex(
RuntimeError,
error_msg,
):
_ = constant_prop_pass(edge.exported_program())
expected_out = edge.exported_program().module()(pred, x)

warn_log = "constant_prop_pass does not constant propagate in control flow modules"
with self.assertLogs(level="WARNING") as log:
program = constant_prop_pass(edge.exported_program())
self.assertIn(warn_log, log.output[0])

out = program.module()(pred, x)
self.assertTrue(torch.allclose(expected_out, out))

# dtype casts in parent module are const propagated
FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default(x, _prop_tensor_constant").run(program.graph_module.code)

def test_constant_prop_pass_quant_primitives(self) -> None:
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.w_int = torch.ones(3, 3, dtype=torch.int8)
self.w_scale = 3.0
self.w_zero_point = 3

def forward(self, x):
w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
self.w_int, self.w_scale, self.w_zero_point, -127, 128, torch.int8)
return torch.nn.functional.linear(x, w_dq)

mod = M()
x = torch.randn([3])
mod(x)
edge = to_edge(
export(mod, (x,), strict=True),
compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
)
constant_prop_pass(edge.exported_program())
FileCheck().check("executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default").run(edge.exported_program().graph_module.code)

def test_mutable_buffers(self) -> None:
def count_copies(gm: torch.fx.GraphModule) -> int:
36 changes: 24 additions & 12 deletions exir/tests/test_quant_fusion_pass.py
@@ -12,7 +12,7 @@
from executorch import exir
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass, quant_fusion_and_const_prop_pass
from executorch.exir.tests.common import register_additional_test_aten_ops
from torch.ao.quantization import ( # @manual
float_qparams_weight_only_qconfig,
@@ -33,7 +33,7 @@
from torch.testing import FileCheck
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

import copy

class TestQuantFusionPass(unittest.TestCase):
@classmethod
@@ -419,6 +419,7 @@ def _test_embedding_torchao(
m = to_edge(
export(model, example_inputs, strict=True), compile_config=compile_config
)
m_copy = copy.deepcopy(m)

# Before pass, we see torchao dequantize and embedding ops
FileCheck().check_count(
@@ -437,13 +438,9 @@

# After pass, we see packing op and quantized embedding op, but no torchao dequantize op
FileCheck().check_count(
"executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
1 if bit_width < 8 else 0,
exactly=True,
"executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default", 1 if bit_width < 8 else 0, exactly=True
).check_count(
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
1,
exactly=True,
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
).check_not(
"executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default"
).run(
@@ -454,9 +451,7 @@

# After constant prop, we see quantized embedding op, but no packing op
FileCheck().check_count(
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
1,
exactly=True,
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
).check_not(
"executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
).run(
@@ -468,4 +463,21 @@
self.assertTrue(torch.allclose(expected_outputs, actual_outputs))

# Can lower to executorch
exec_prog = m.to_executorch() # noqa: F841
exec_prog = m.to_executorch() # noqa


# Alternative flow 2: use quant_fusion_and_const_prop_pass on the exported program
quant_fusion_and_const_prop_pass(m_copy.exported_program())
FileCheck().check_count(
f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
).check_not(
"executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
).run(
m_copy.exported_program().graph_module.code
)

actual_outputs2 = m_copy.exported_program().module()(*example_inputs)
self.assertTrue(torch.allclose(expected_outputs, actual_outputs2))

# Can lower to executorch
exec_prog2 = m_copy.to_executorch() # noqa
10 changes: 6 additions & 4 deletions extension/llm/export/builder.py
@@ -20,7 +20,6 @@
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
DuplicateDynamicQuantChainPass,
)
from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
from executorch.exir.backend.partitioner import Partitioner

@@ -29,7 +28,6 @@

from executorch.exir.pass_base import ExportPass
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass

from executorch.extension.export_util.utils import export_to_edge, save_pte_program
@@ -504,12 +502,15 @@ def to_executorch(
"""
Lower the model to executorch and get an ExecutorchProgram.
"""
# QuantFusionPass is not necessary because we set do_quant_fusion_and_const_prop=True
# in ExecutorchBackendConfig
to_executorch_passes = [
# If there are Linear operations left in the graph, let's execute
# them with the optimized op_linear rather than materializing a
# transpose followed by a regular op_mm.
ConvertToLinearPass(),
QuantFusionPass(),
# Disabling because ConvertToLinearPass is not a sound pass:
# https://github.com/pytorch/executorch/issues/10499
# ConvertToLinearPass(),
]
if passes:
# pyre-fixme[6]: In call `list.extend`, for 1st positional argument,
@@ -526,6 +527,7 @@
# Optional[PassResult]]]` but got `List[Union[ConvertToLinearPass,
# QuantFusionPass]]`.
passes=to_executorch_passes,
do_quant_fusion_and_const_prop=True,
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
)