Skip to content

Commit 57dbe1e

Browse files
eigen-k authored and facebook-github-bot committed
Introduce apply_torch_ops_passes to test mul fusion e2e (#11741)
Summary: This diff 1) extends export_to_edge() with an apply_torch_ops_passes() call. We need this to capture the actual mul argument value so that it can be fused in later passes; otherwise, the constant gets lifted by _lower_ep_to_edge(). 2) implements an e2e mul fusion test. We test that both the ReplaceMulTensorWithMulAndFullOpsPass() and FuseMulTensorIntoQuantPass() passes get applied correctly, which results in mul.Tensor being removed completely. Reviewed By: hsharma35 Differential Revision: D76469613
1 parent f9a3ca8 commit 57dbe1e

File tree

4 files changed

+46
-35
lines changed

4 files changed

+46
-35
lines changed

backends/cadence/aot/compiler.py

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import logging
1010
from pathlib import Path
11-
from typing import Callable, cast, Optional
11+
from typing import Optional
1212

1313
import executorch.backends.cadence.aot.ops_registrations # noqa
1414
import torch
@@ -32,7 +32,6 @@
3232
ExecutorchBackendConfig,
3333
ExecutorchProgramManager,
3434
)
35-
from executorch.exir.pass_base import PassResult
3635
from executorch.exir.passes import ToOutVarPass
3736
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
3837
from executorch.exir.program._program import to_edge_with_preserved_ops
@@ -41,7 +40,7 @@
4140
from torch.export.exported_program import ExportedProgram
4241
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
4342

44-
from .passes import get_cadence_passes
43+
from .passes import apply_exir_ops_passes, apply_torch_ops_passes
4544

4645
from .utils import print_ops_info
4746

@@ -265,6 +264,9 @@ def export_to_edge(
265264
# Export the model into an ExportedProgram.
266265
expo_program = trace(model, inputs)
267266

267+
# Apply passes which transform the ExportedProgram before it gets lowered to edge.
268+
expo_program = apply_torch_ops_passes(expo_program)
269+
268270
# Lower the model to edge IR.
269271
edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
270272

@@ -308,14 +310,7 @@ def _lower_ep_to_cadence(
308310
Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
309311
"""
310312
edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
311-
cadence_passes = get_cadence_passes(opt_level)
312-
313-
# Run a couple required passes for quant/dequant ops
314-
cadence_prog_manager = edge_prog_manager.transform(
315-
cast(
316-
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
317-
)
318-
)
313+
cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
319314
return cadence_prog_manager
320315

321316

@@ -326,14 +321,7 @@ def export_to_cadence(
326321
opt_level: int = 1,
327322
) -> EdgeProgramManager:
328323
edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
329-
cadence_passes = get_cadence_passes(opt_level)
330-
331-
# Run a couple required passes for quant/dequant ops
332-
cadence_prog_manager = edge_prog_manager.transform(
333-
cast(
334-
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
335-
)
336-
)
324+
cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
337325
return cadence_prog_manager
338326

339327

@@ -370,15 +358,8 @@ def export_to_executorch_gen_etrecord(
370358
memory_config: Optional[MemoryConfig] = None,
371359
dump_graphs: bool = False,
372360
) -> ExecutorchProgramManager:
373-
cadence_passes = get_cadence_passes(opt_level)
374361
edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
375-
376-
# Run a couple required passes for quant/dequant ops
377-
cadence_prog_manager = edge_prog_manager.transform(
378-
cast(
379-
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
380-
)
381-
)
362+
cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
382363

383364
# Print some information to terminal
384365
print_ops_info(

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,6 +1127,7 @@ class CadenceFuseOpsInGraph:
11271127
FuseCascadedTransposeOrPermuteOps,
11281128
FuseCascadedViewOps,
11291129
FuseQuantDequantToRequantizePass,
1130+
FuseMulTensorIntoQuantPass,
11301131
FuseMulTensorIntoDequantPass,
11311132
FuseMulScalarIntoDequantPass,
11321133
FuseFullThenReshapePass,

backends/cadence/aot/passes.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import Any, List, Optional
9+
from typing import Any, Callable, cast, List, Optional
1010

1111
import torch
1212
import torch.fx
@@ -28,13 +28,18 @@
2828
RemoveRedundantOps,
2929
)
3030
from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
31-
from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
31+
from executorch.backends.cadence.aot.replace_ops import (
32+
CadenceReplaceOpsInGraph,
33+
ReplaceMulTensorWithMulAndFullOpsPass,
34+
)
3235
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
36+
from executorch.exir import EdgeProgramManager
3337
from executorch.exir.pass_base import ExportPass, PassResult
3438
from executorch.exir.pass_manager import PassManager, PassType
3539
from executorch.exir.passes import dead_code_elimination_pass
3640
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
3741
from executorch.exir.passes.spec_prop_pass import SpecPropPass
42+
from torch.export.exported_program import ExportedProgram
3843

3944

4045
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -89,14 +94,37 @@ def get_passes_in_default_order() -> List[ExportPass]:
8994
return pytree.tree_flatten(passes)[0]
9095

9196

92-
def get_cadence_passes(
97+
def apply_exir_ops_passes(
9398
opt_level: int,
94-
) -> List[Optional[PassResult]]:
99+
edge_prog_manager: EdgeProgramManager,
100+
) -> EdgeProgramManager:
95101
passes = get_passes_in_default_order()
96102
pass_filter = create_cadence_pass_filter(opt_level)
97-
filtered_passes = [
98-
# pyre-ignore[20]: Expect argument graph_module
99-
filtered_pass()
103+
cadence_passes = [
104+
(
105+
lambda graph_module, filtered_pass=filtered_pass: filtered_pass()(
106+
graph_module
107+
)
108+
)
100109
for filtered_pass in list(filter(pass_filter, passes))
101110
]
102-
return filtered_passes
111+
cadence_prog_manager = edge_prog_manager.transform(
112+
cast(
113+
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
114+
)
115+
)
116+
return cadence_prog_manager
117+
118+
119+
def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
120+
"""
121+
Applies compiler passes on torch.ops IR, including torch.ops.aten, torch.ops.cadence, etc.
122+
expo_program is expected to be the output of the torch.export.export().
123+
"""
124+
125+
aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
126+
ReplaceMulTensorWithMulAndFullOpsPass()
127+
]
128+
result = PassManager(aten_passes)(expo_program.graph_module)
129+
expo_program.graph_module = result.graph_module
130+
return expo_program

backends/cadence/aot/replace_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2449,4 +2449,5 @@ class CadenceReplaceOpsInGraph:
24492449
ReplaceAtenApproxGeluWithApproxGeluPass,
24502450
ReplaceSplitWithSlicePass,
24512451
ReplacePowWithMulPass,
2452+
ReplaceMulTensorWithMulAndFullOpsPass,
24522453
]

0 commit comments

Comments
 (0)