
Commit 97a61f4

Introduce apply_torch_ops_aten_passes to test mul fusion e2e
Differential Revision: D76469613
Pull Request resolved: #11741
Parent: 378f062

3 files changed: 49 additions, 36 deletions

backends/cadence/aot/compiler.py

Lines changed: 12 additions & 28 deletions
@@ -8,7 +8,7 @@
 
 import logging
 from pathlib import Path
-from typing import Callable, cast, Optional
+from typing import Optional
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -32,7 +32,6 @@
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
 )
-from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge_with_preserved_ops
@@ -41,7 +40,7 @@
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
-from .passes import get_cadence_passes
+from .passes import apply_exir_ops_passes, apply_torch_ops_passes
 
 from .utils import print_ops_info
 
@@ -262,14 +261,20 @@ def export_to_edge(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
+    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
 ) -> EdgeProgramManager:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
 
     # Export the model into an ExportedProgram.
     expo_program = trace(model, inputs)
 
+    # Apply passes which transform the ExportedProgram before it gets lowered to edge.
+    expo_program = apply_torch_ops_passes(expo_program)
+
     # Lower the model to edge IR.
-    edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
+    edge_prog_manager = _lower_ep_to_edge(
+        expo_program, dump_graphs, constant_methods, core_aten_exceptions
+    )
 
     return edge_prog_manager
 
@@ -311,14 +316,7 @@ def _lower_ep_to_cadence(
     Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
     """
     edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
 
 
@@ -329,14 +327,7 @@ def export_to_cadence(
     opt_level: int = 1,
 ) -> EdgeProgramManager:
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager
 
 
@@ -373,15 +364,8 @@ def export_to_executorch_gen_etrecord(
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
 ) -> ExecutorchProgramManager:
-    cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
 
     # Print some information to terminal
     print_ops_info(
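
For context, a minimal caller-side sketch of the updated export_to_edge entry point. The SmallMul module and the op listed in core_aten_exceptions are illustrative assumptions, not part of this commit:

import torch

from executorch.backends.cadence.aot.compiler import export_to_edge


class SmallMul(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A constant multiply that the new aten-level pass can rewrite
        # before the program is lowered to edge IR.
        return x * 0.5


edge_prog_manager = export_to_edge(
    SmallMul(),
    (torch.randn(4, 8),),
    # Hypothetical exception list: ops passed here are forwarded to
    # _lower_ep_to_edge via the new core_aten_exceptions argument.
    core_aten_exceptions=[torch.ops.aten.mul.Tensor],
)
print(edge_prog_manager.exported_program().graph_module.graph)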

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 0 deletions
@@ -1127,6 +1127,7 @@ class CadenceFuseOpsInGraph:
         FuseCascadedTransposeOrPermuteOps,
         FuseCascadedViewOps,
         FuseQuantDequantToRequantizePass,
+        FuseMulTensorIntoQuantPass,
         FuseMulTensorIntoDequantPass,
         FuseMulScalarIntoDequantPass,
         FuseFullThenReshapePass,
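
As background on what registering FuseMulTensorIntoQuantPass enables, the sketch below shows why a constant multiply feeding a quantize op can be folded into the quantization scale, assuming the usual affine formula q = round(x / scale) + zero_point. This illustrates the arithmetic of such a rewrite, not the pass implementation itself:

import torch

# Assumed affine quantize formula: q = clamp(round(x / scale) + zero_point, qmin, qmax).
x = torch.randn(16)
c, scale, zero_point = 2.0, 0.1, 0

# Original pattern: multiply by a constant, then quantize.
q_before = torch.clamp(torch.round((x * c) / scale) + zero_point, -128, 127)

# Fused form: drop the multiply and divide the scale by the constant instead,
# since (x * c) / scale == x / (scale / c).
q_after = torch.clamp(torch.round(x / (scale / c)) + zero_point, -128, 127)

assert torch.equal(q_before, q_after)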

backends/cadence/aot/passes.py

Lines changed: 36 additions & 8 deletions
@@ -6,7 +6,7 @@
 
 # pyre-strict
 
-from typing import Any, List, Optional
+from typing import Any, Callable, cast, List, Optional
 
 import torch
 import torch.fx
@@ -28,13 +28,18 @@
     RemoveRedundantOps,
 )
 from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
-from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
+from executorch.backends.cadence.aot.replace_ops import (
+    CadenceReplaceOpsInGraph,
+    ReplaceMulTensorWithMulAndFullOpsPass,
+)
 from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
+from executorch.exir import EdgeProgramManager
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
+from torch.export.exported_program import ExportedProgram
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -89,14 +94,37 @@ def get_passes_in_default_order() -> List[ExportPass]:
     return pytree.tree_flatten(passes)[0]
 
 
-def get_cadence_passes(
+def apply_exir_ops_passes(
     opt_level: int,
-) -> List[Optional[PassResult]]:
+    edge_prog_manager: EdgeProgramManager,
+) -> EdgeProgramManager:
     passes = get_passes_in_default_order()
     pass_filter = create_cadence_pass_filter(opt_level)
-    filtered_passes = [
-        # pyre-ignore[20]: Expect argument graph_module
-        filtered_pass()
+    cadence_passes = [
+        (
+            lambda graph_module, filtered_pass=filtered_pass: filtered_pass()(
+                graph_module
+            )
+        )
         for filtered_pass in list(filter(pass_filter, passes))
     ]
-    return filtered_passes
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
+def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
+    """
+    Applies compiler passes on torch.ops IR, including torch.ops.aten, torch.ops.cadence, etc.
+    expo_program is expected to be the output of the torch.export.export().
+    """
+
+    aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
+        ReplaceMulTensorWithMulAndFullOpsPass()
+    ]
+    # TODO(T230417247): Use PassResult which is currently ignored.
+    PassManager(aten_passes)(expo_program.graph_module)
+    return expo_program
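
A minimal usage sketch of the new apply_torch_ops_passes on the output of torch.export.export; the ScaledAdd toy module is assumed here for illustration only:

import torch

from executorch.backends.cadence.aot.passes import apply_torch_ops_passes


class ScaledAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * 3.0 + y


ep = torch.export.export(ScaledAdd(), (torch.randn(2, 4), torch.randn(2, 4)))
# Runs the aten-level pass list (currently just ReplaceMulTensorWithMulAndFullOpsPass)
# in place on ep.graph_module; the PassResult is ignored per the TODO above.
ep = apply_torch_ops_passes(ep)
print(ep.graph_module.graph)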
