Skip to content

Commit ccc998e

Browse files
eigen-k
authored and facebook-github-bot committed
Introduce apply_torch_ops_aten_passes to test mul fusion e2e (#11741)
Summary: Pull Request resolved: #11741 This diff 1) extends export_to_edge() with an apply_prelowering_passes() call. We need this to grab the actual mul argument value and fuse it in later passes. Otherwise, the constant gets lifted by _lower_ep_to_edge(). 2) implements an e2e mul fusion test. We test that both the ReplaceMulTensorWithMulAndFullOpsPass() and FuseMulTensorIntoQuantPass() passes get applied correctly, which results in mul.Tensor being removed completely. Reviewed By: hsharma35 Differential Revision: D76469613
1 parent 3b1c7fd commit ccc998e

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

backends/cadence/aot/compiler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
from torch.export.exported_program import ExportedProgram
4242
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
4343

44-
from .passes import get_cadence_passes
44+
from .passes import get_edge_passes
45+
from .passes import apply_torch_ops_passes
4546

4647
from .utils import print_ops_info
4748

@@ -265,6 +266,9 @@ def export_to_edge(
265266
# Export the model into an ExportedProgram.
266267
expo_program = trace(model, inputs)
267268

269+
# Apply passes which transform the ExportedProgram before it gets lowered to edge.
270+
apply_torch_ops_passes(expo_program)
271+
268272
# Lower the model to edge IR.
269273
edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
270274

@@ -306,7 +310,7 @@ def _lower_ep_to_cadence(
306310
Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
307311
"""
308312
edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
309-
cadence_passes = get_cadence_passes(opt_level)
313+
cadence_passes = get_edge_passes(opt_level)
310314

311315
# Run a couple required passes for quant/dequant ops
312316
cadence_prog_manager = edge_prog_manager.transform(
@@ -324,7 +328,7 @@ def export_to_cadence(
324328
opt_level: int = 1,
325329
) -> EdgeProgramManager:
326330
edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
327-
cadence_passes = get_cadence_passes(opt_level)
331+
cadence_passes = get_edge_passes(opt_level)
328332

329333
# Run a couple required passes for quant/dequant ops
330334
cadence_prog_manager = edge_prog_manager.transform(
@@ -368,7 +372,7 @@ def export_to_executorch_gen_etrecord(
368372
memory_config: Optional[MemoryConfig] = None,
369373
dump_graphs: bool = False,
370374
) -> ExecutorchProgramManager:
371-
cadence_passes = get_cadence_passes(opt_level)
375+
cadence_passes = get_edge_passes(opt_level)
372376
edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
373377

374378
# Run a couple required passes for quant/dequant ops

backends/cadence/aot/passes.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import Any, List, Optional
9+
from typing import Any, Callable, List, Optional
1010

1111
import torch
1212
import torch.fx
@@ -28,13 +28,17 @@
2828
RemoveRedundantOps,
2929
)
3030
from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
31-
from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
31+
from executorch.backends.cadence.aot.replace_ops import (
32+
CadenceReplaceOpsInGraph,
33+
ReplaceMulTensorWithMulAndFullOpsPass,
34+
)
3235
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
3336
from executorch.exir.pass_base import ExportPass, PassResult
3437
from executorch.exir.pass_manager import PassManager, PassType
3538
from executorch.exir.passes import dead_code_elimination_pass
3639
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
3740
from executorch.exir.passes.spec_prop_pass import SpecPropPass
41+
from torch.export.exported_program import ExportedProgram
3842

3943

4044
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -89,7 +93,7 @@ def get_passes_in_default_order() -> List[ExportPass]:
8993
return pytree.tree_flatten(passes)[0]
9094

9195

92-
def get_cadence_passes(
96+
def get_edge_passes(
9397
opt_level: int,
9498
) -> List[Optional[PassResult]]:
9599
passes = get_passes_in_default_order()
@@ -100,3 +104,14 @@ def get_cadence_passes(
100104
for filtered_pass in list(filter(pass_filter, passes))
101105
]
102106
return filtered_passes
107+
108+
def apply_torch_ops_passes(expo_program: ExportedProgram) -> None:
109+
"""
110+
Applies compiler passes on torch.ops IR, including torch.ops.aten, torch.ops.cadence, etc.
111+
expo_program is expected to be the output of the torch.export.export().
112+
"""
113+
114+
aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
115+
ReplaceMulTensorWithMulAndFullOpsPass()
116+
]
117+
PassManager(aten_passes)(expo_program.graph_module)

0 commit comments

Comments (0)