
Commit 0765b24 ("up")

1 parent c0a2cd0

File tree: 6 files changed (+54, -25 lines)

exir/passes/constant_prop_pass.py

Lines changed: 11 additions & 3 deletions
@@ -6,9 +6,9 @@

 # pyre-unsafe

+import logging
 from collections import OrderedDict
 from typing import cast, Mapping, Optional
-import logging

 import torch
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -30,6 +30,10 @@
 # Propagating aten.full can significantly increase compiled model size.
 _DEFAULT_SKIP_TARGETS = {exir_ops.edge.aten.full.default}

+# Skipping transpose/permute for now because https://github.com/pytorch/executorch/issues/10499
+_DEFAULT_SKIP_TARGETS.add(exir_ops.edge.transpose.int)
+_DEFAULT_SKIP_TARGETS.add(exir_ops.edge.permute.default)
+
 # Do not const prop quantization primitives
 _QDQ_OPS = [
     exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
@@ -42,7 +46,8 @@
     exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
 ]
 try:
-    import torchao # noqa: F401
+    import torchao  # noqa: F401
+
     _QDQ_OPS.extend(
         [
             exir_ops.edge.torchao.dequantize_affine.default,
@@ -66,6 +71,7 @@
     torch.layout,
 )

+
 def is_const(
     arg,
     exported_program: ExportedProgram,
@@ -333,7 +339,9 @@ def constant_prop_pass(
         if node.target == torch.ops.higher_order.cond
     ]
     if len(has_control_flow) > 0:
-        logging.warning("constant_prop_pass does not constant propagate in control flow modules")
+        logging.warning(
+            "constant_prop_pass does not constant propagate in control flow modules"
+        )

     const_node_to_tensor = get_propagated_const_tensor_dict(
         exported_program, custom_skip_targets
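For context, constant_prop_pass folds operations whose inputs are all constant into precomputed tensors, and _DEFAULT_SKIP_TARGETS lists ops it deliberately leaves alone. A minimal usage sketch follows; it assumes an EdgeProgramManager named `edge` produced by exir.to_edge() (not part of this commit) and the custom_skip_targets parameter implied by the call to get_propagated_const_tensor_dict above.

    # Minimal sketch (not from this commit): run constant propagation on an
    # exported edge program while skipping additional targets.
    # `edge` is a hypothetical EdgeProgramManager returned by exir.to_edge().
    from executorch.exir.dialects._ops import ops as exir_ops
    from executorch.exir.passes.constant_prop_pass import constant_prop_pass

    extra_skips = {exir_ops.edge.aten.full.default}  # example skip set; adjust as needed
    program = constant_prop_pass(edge.exported_program(), custom_skip_targets=extra_skips)
    # Propagated constants show up as _prop_tensor_constant* attributes in the graph,
    # which is what the FileCheck assertions in the tests below look for.
    print(program.graph_module.code)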

exir/passes/quant_fusion_pass.py

Lines changed: 3 additions & 3 deletions
@@ -7,11 +7,11 @@
 import torch
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
+from executorch.exir.passes.constant_prop_pass import constant_prop_pass
+from torch.export import ExportedProgram
 from torch.fx import GraphModule, subgraph_rewriter
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.utils import _pytree as pytree
-from executorch.exir.passes.constant_prop_pass import constant_prop_pass
-from torch.export import ExportedProgram

 from ._quant_patterns_and_replacements import get_quant_patterns_and_replacements

@@ -147,7 +147,7 @@ def quant_fusion_and_const_prop_pass(program: ExportedProgram) -> ExportedProgram:
     gm = program.graph_module
     gm_res = QuantFusionPass(_fix_node_meta_val=True)(gm)
     gm = gm_res.graph_module
-
+
     # Do const prop pass to remove packing/dtype conversion ops
     program = constant_prop_pass(program)
     return program
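The helper above chains QuantFusionPass with constant propagation, which is how the tests later in this commit exercise it. A rough usage sketch, assuming an EdgeProgramManager `edge` (not defined in this commit) that already contains quantize/dequantize patterns:

    # Rough sketch; `edge` is an assumed EdgeProgramManager holding a quantized model.
    from executorch.exir.passes.quant_fusion_pass import quant_fusion_and_const_prop_pass

    program = quant_fusion_and_const_prop_pass(edge.exported_program())
    # After the pass, fused quantized ops remain while packing/dtype-conversion ops
    # have been const-propagated away, as the FileCheck assertions below verify.
    print(program.graph_module.code)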

exir/program/_program.py

Lines changed: 1 addition & 1 deletion
@@ -1532,7 +1532,7 @@ def to_executorch(
            raise Exception(
                "Cannot run do_quant_fusion_and_const_prop on a graph with a backward signature intended for on-device training."
                " Please set do_quant_fusion_and_const_prop to False in the ExecutorchBackendConfig."
-                )
+            )
        program = quant_fusion_and_const_prop_pass(program)
        program = weights_to_outputs_pass(program)
        program = unsafe_remove_auto_functionalized_pass(program)
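The guard above fires only when the config flag is enabled on a training graph. A hedged sketch of toggling that flag, assuming an inference-only EdgeProgramManager `edge` (not part of this commit):

    # Sketch only; `edge` is an assumed inference-only EdgeProgramManager.
    from executorch.exir import ExecutorchBackendConfig

    et_program = edge.to_executorch(
        ExecutorchBackendConfig(do_quant_fusion_and_const_prop=True)
    )
    # For graphs with a backward signature (on-device training), leave the flag at
    # False, as the exception message above instructs.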

exir/tests/test_passes.py

Lines changed: 16 additions & 7 deletions
@@ -1294,7 +1294,9 @@ def false_fn(self, val):
                 return self.linear(val) - self.f(val)

             def forward(self, pred, x):
-                out = torch.nn.functional.linear(x, self.w.to(torch.float16).to(torch.float32))
+                out = torch.nn.functional.linear(
+                    x, self.w.to(torch.float16).to(torch.float32)
+                )
                 return torch.ops.higher_order.cond(
                     pred, self.true_fn, self.false_fn, [out]
                 )
@@ -1308,7 +1310,9 @@ def forward(self, pred, x):
         )
         expected_out = edge.exported_program().module()(pred, x)

-        warn_log = "constant_prop_pass does not constant propagate in control flow modules"
+        warn_log = (
+            "constant_prop_pass does not constant propagate in control flow modules"
+        )
         with self.assertLogs(level="WARNING") as log:
             program = constant_prop_pass(edge.exported_program())
             self.assertIn(warn_log, log.output[0])
@@ -1317,8 +1321,10 @@ def forward(self, pred, x):
         self.assertTrue(torch.allclose(expected_out, out))

         # dtype casts in parent module are const propagated
-        FileCheck().check("executorch_exir_dialects_edge__ops_aten_mm_default(x, _prop_tensor_constant").run(program.graph_module.code)
-
+        FileCheck().check(
+            "executorch_exir_dialects_edge__ops_aten_mm_default(x, _prop_tensor_constant"
+        ).run(program.graph_module.code)
+
     def test_constant_prop_pass_quant_primitives(self) -> None:
         class M(torch.nn.Module):
             def __init__(self):
@@ -1329,9 +1335,10 @@ def __init__(self):

             def forward(self, x):
                 w_dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
-                    self.w_int, self.w_scale, self.w_zero_point, -127, 128, torch.int8)
+                    self.w_int, self.w_scale, self.w_zero_point, -127, 128, torch.int8
+                )
                 return torch.nn.functional.linear(x, w_dq)
-
+
         mod = M()
         x = torch.randn([3])
         mod(x)
@@ -1340,7 +1347,9 @@ def forward(self, x):
             compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
         )
         constant_prop_pass(edge.exported_program())
-        FileCheck().check("executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default").run(edge.exported_program().graph_module.code)
+        FileCheck().check(
+            "executorch_exir_dialects_edge__ops_quantized_decomposed_dequantize_per_tensor_default"
+        ).run(edge.exported_program().graph_module.code)

     def test_mutable_buffers(self) -> None:
         def count_copies(gm: torch.fx.GraphModule) -> int:

exir/tests/test_quant_fusion_pass.py

Lines changed: 20 additions & 9 deletions
@@ -6,13 +6,17 @@

 # pyre-strict

+import copy
 import unittest

 import torch
 from executorch import exir
 from executorch.exir import EdgeCompileConfig, to_edge
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
-from executorch.exir.passes.quant_fusion_pass import QuantFusionPass, quant_fusion_and_const_prop_pass
+from executorch.exir.passes.quant_fusion_pass import (
+    quant_fusion_and_const_prop_pass,
+    QuantFusionPass,
+)
 from executorch.exir.tests.common import register_additional_test_aten_ops
 from torch.ao.quantization import (  # @manual
     float_qparams_weight_only_qconfig,
@@ -33,7 +37,7 @@
 from torch.testing import FileCheck
 from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
-import copy
+

 class TestQuantFusionPass(unittest.TestCase):
     @classmethod
@@ -438,9 +442,13 @@ def _test_embedding_torchao(

         # After pass, we see packing op and quantized embedding op, but no torchao dequantize op
         FileCheck().check_count(
-            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default", 1 if bit_width < 8 else 0, exactly=True
+            "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
+            1 if bit_width < 8 else 0,
+            exactly=True,
         ).check_count(
-            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
+            1,
+            exactly=True,
         ).check_not(
             "executorch_exir_dialects_edge__ops_torchao_dequantize_affine_default"
         ).run(
@@ -451,7 +459,9 @@ def _test_embedding_torchao(

         # After constant prop, we see quantized embedding op, but no packing op
         FileCheck().check_count(
-            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
+            1,
+            exactly=True,
         ).check_not(
             "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
         ).run(
@@ -463,13 +473,14 @@ def _test_embedding_torchao(
         self.assertTrue(torch.allclose(expected_outputs, actual_outputs))

         # Can lower to executorch
-        exec_prog = m.to_executorch() # noqa
-
+        exec_prog = m.to_executorch()  # noqa

         # Alternative flow 2 using quant_fusion_pass on exported program
         quant_fusion_and_const_prop_pass(m_copy.exported_program())
         FileCheck().check_count(
-            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}", 1, exactly=True,
+            f"executorch_exir_dialects_edge__ops_quantized_decomposed_embedding_{embedding_suffix}",
+            1,
+            exactly=True,
         ).check_not(
             "executorch_exir_dialects_edge__ops_quant_fusion__pack_embedding_weight_default",
         ).run(
@@ -480,4 +491,4 @@ def _test_embedding_torchao(
         self.assertTrue(torch.allclose(expected_outputs, actual_outputs2))

         # Can lower to executorch
-        exec_prog2 = m_copy.to_executorch() # noqa
+        exec_prog2 = m_copy.to_executorch()  # noqa

extension/llm/export/builder.py

Lines changed: 3 additions & 2 deletions
@@ -20,6 +20,7 @@
 from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
     DuplicateDynamicQuantChainPass,
 )
+from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass
 from executorch.exir import EdgeProgramManager, to_edge_transform_and_lower
 from executorch.exir.backend.partitioner import Partitioner

@@ -508,9 +509,9 @@ def to_executorch(
             # If there are Linear operations left in the graph, let's execute
             # them with the optimized op_linear rather than materializing a
             # transpose followed by a regular op_mm.
-            # Disabling because ConvertToLinearPass is not a sound pass:
+            # TODO: ConvertToLinearPass is not a sound pass and we should fix it
             # https://github.com/pytorch/executorch/issues/10499
-            # ConvertToLinearPass(),
+            ConvertToLinearPass(),
         ]
         if passes:
             # pyre-fixme[6]: In call `list.extend`, for 1st positional argument,
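For context, the re-enabled ConvertToLinearPass recovers single linear ops so the optimized op_linear kernel runs instead of a materialized transpose followed by op_mm, per the comments above. A hedged sketch of invoking it directly, assuming it follows the standard torch.fx pass calling convention (an instance called on a GraphModule returns a PassResult) and an EdgeProgramManager `edge` that is not part of this commit:

    # Sketch only; `edge` is a hypothetical EdgeProgramManager, and the call
    # convention is assumed to match torch.fx pass infrastructure.
    from executorch.backends.xnnpack._passes.convert_to_linear import ConvertToLinearPass

    gm = edge.exported_program().graph_module
    result = ConvertToLinearPass()(gm)  # rewrite transpose + mm patterns into linear
    print(result.graph_module.code)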
