Commit 9a56264

eigen-k authored and facebook-github-bot committed
Introduce apply_torch_ops_passes to test mul fusion e2e (pytorch#11741)
Summary:
This diff
1) extends export_to_edge() with an apply_torch_ops_passes() call. We need this to grab the actual mul argument value and fuse it in later passes; otherwise, the constant gets lifted by _lower_ep_to_edge().
2) implements an e2e mul fusion test. We test that both the ReplaceMulTensorWithMulAndFullOpsPass() and FuseMulTensorIntoQuantPass() passes are applied correctly, which results in mul.Tensor being removed completely.

Reviewed By: hsharma35

Differential Revision: D76469613
1 parent bf4acb3 commit 9a56264
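For context (not code from this diff): the new e2e test described above boils down to asserting that no mul.Tensor nodes survive once both passes have run. A minimal sketch of that kind of graph check, using a hypothetical helper:

# Hypothetical helper, illustrating the assertion the e2e test makes: after
# ReplaceMulTensorWithMulAndFullOpsPass and FuseMulTensorIntoQuantPass run,
# no mul.Tensor call_function nodes should remain in the graph.
import torch

def count_mul_tensor_nodes(graph_module: torch.fx.GraphModule) -> int:
    return sum(
        1
        for node in graph_module.graph.nodes
        if node.op == "call_function" and "mul.Tensor" in str(node.target)
    )

# In the test, the graph module would come from the Cadence edge program, e.g.
#     assert count_mul_tensor_nodes(cadence_prog_manager.exported_program().graph_module) == 0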

File tree

3 files changed (+47, -36 lines)


backends/cadence/aot/compiler.py

Lines changed: 10 additions & 28 deletions
@@ -8,7 +8,7 @@

 import logging
 from pathlib import Path
-from typing import Callable, cast, Optional
+from typing import Optional

 import executorch.backends.cadence.aot.ops_registrations # noqa
 import torch
@@ -32,7 +32,6 @@
     ExecutorchBackendConfig,
     ExecutorchProgramManager,
 )
-from executorch.exir.pass_base import PassResult
 from executorch.exir.passes import ToOutVarPass
 from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
 from executorch.exir.program._program import to_edge_with_preserved_ops
@@ -41,7 +40,7 @@
 from torch.export.exported_program import ExportedProgram
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

-from .passes import get_cadence_passes
+from .passes import apply_exir_ops_passes, apply_torch_ops_passes

 from .utils import print_ops_info

@@ -256,14 +255,18 @@ def export_to_edge(
     inputs: tuple[object, ...],
     dump_graphs: bool = False,
     constant_methods: Optional[dict[str, object]] = None,
+    core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
 ) -> EdgeProgramManager:
     assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

     # Export the model into an ExportedProgram.
     expo_program = trace(model, inputs)

+    # Apply passes which transform the ExportedProgram before it gets lowered to edge.
+    expo_program = apply_torch_ops_passes(expo_program)
+
     # Lower the model to edge IR.
-    edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
+    edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods, core_aten_exceptions)

     return edge_prog_manager

@@ -305,14 +308,7 @@ def _lower_ep_to_cadence(
     Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
     """
     edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager


@@ -323,14 +319,7 @@ def export_to_cadence(
     opt_level: int = 1,
 ) -> EdgeProgramManager:
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
-    cadence_passes = get_cadence_passes(opt_level)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
     return cadence_prog_manager


@@ -367,15 +356,8 @@ def export_to_executorch_gen_etrecord(
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
 ) -> ExecutorchProgramManager:
-    cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
-
-    # Run a couple required passes for quant/dequant ops
-    cadence_prog_manager = edge_prog_manager.transform(
-        cast(
-            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
-        )
-    )
+    cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)

     # Print some information to terminal
     print_ops_info(
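Not part of the diff, but a minimal usage sketch of the updated export_to_edge() entry point; TinyScale is a hypothetical module, and core_aten_exceptions is simply forwarded to _lower_ep_to_edge():

# Sketch only: exercises the new export_to_edge() signature. apply_torch_ops_passes()
# now runs on the ExportedProgram before it is lowered to edge IR, and the new
# core_aten_exceptions argument is passed through to _lower_ep_to_edge().
import torch
from executorch.backends.cadence.aot.compiler import export_to_edge

class TinyScale(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0  # constant multiplier that ReplaceMulTensorWithMulAndFullOpsPass targets

edge_prog_manager = export_to_edge(
    TinyScale(),
    (torch.randn(1, 8),),
    dump_graphs=False,
    constant_methods=None,
    core_aten_exceptions=None,  # optional list of torch._ops.OpOverload, per the new parameter
)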

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 0 deletions
@@ -1127,6 +1127,7 @@ class CadenceFuseOpsInGraph:
         FuseCascadedTransposeOrPermuteOps,
         FuseCascadedViewOps,
         FuseQuantDequantToRequantizePass,
+        FuseMulTensorIntoQuantPass,
         FuseMulTensorIntoDequantPass,
         FuseMulScalarIntoDequantPass,
         FuseFullThenReshapePass,
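Not part of the diff, but a quick numeric illustration of why folding a constant multiply into a following per-tensor quantize is value-preserving (the assumed pattern behind FuseMulTensorIntoQuantPass; the exact matching logic lives in that pass):

# Illustration only: quantize(x * c, scale, zp) produces the same integers as
# quantize(x, scale / c, zp), since round((x * c) / scale) == round(x / (scale / c)).
import torch

x = torch.randn(64)
c, scale, zero_point = 0.37, 0.05, 0

q_before = torch.quantize_per_tensor(x * c, scale, zero_point, torch.qint8)
q_fused = torch.quantize_per_tensor(x, scale / c, zero_point, torch.qint8)

# Allow for float rounding exactly at bin edges; values are otherwise identical.
diff = (q_before.int_repr().int() - q_fused.int_repr().int()).abs().max()
assert diff <= 1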

backends/cadence/aot/passes.py

Lines changed: 36 additions & 8 deletions
@@ -6,7 +6,7 @@

 # pyre-strict

-from typing import Any, List, Optional
+from typing import Any, Callable, cast, List, Optional

 import torch
 import torch.fx
@@ -28,13 +28,18 @@
     RemoveRedundantOps,
 )
 from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
-from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
+from executorch.backends.cadence.aot.replace_ops import (
+    CadenceReplaceOpsInGraph,
+    ReplaceMulTensorWithMulAndFullOpsPass,
+)
 from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
+from executorch.exir import EdgeProgramManager
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
+from torch.export.exported_program import ExportedProgram


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -89,14 +94,37 @@ def get_passes_in_default_order() -> List[ExportPass]:
     return pytree.tree_flatten(passes)[0]


-def get_cadence_passes(
+def apply_exir_ops_passes(
     opt_level: int,
-) -> List[Optional[PassResult]]:
+    edge_prog_manager: EdgeProgramManager,
+) -> EdgeProgramManager:
     passes = get_passes_in_default_order()
     pass_filter = create_cadence_pass_filter(opt_level)
-    filtered_passes = [
-        # pyre-ignore[20]: Expect argument graph_module
-        filtered_pass()
+    cadence_passes = [
+        (
+            lambda graph_module, filtered_pass=filtered_pass: filtered_pass()(
+                graph_module
+            )
+        )
         for filtered_pass in list(filter(pass_filter, passes))
     ]
-    return filtered_passes
+    cadence_prog_manager = edge_prog_manager.transform(
+        cast(
+            list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
+        )
+    )
+    return cadence_prog_manager
+
+
+def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
+    """
+    Applies compiler passes on torch.ops IR, including torch.ops.aten, torch.ops.cadence, etc.
+    expo_program is expected to be the output of the torch.export.export().
+    """
+
+    aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
+        ReplaceMulTensorWithMulAndFullOpsPass()
+    ]
+    # TODO(T230417247): Use PassResult which is currently ignored.
+    PassManager(aten_passes)(expo_program.graph_module)
+    return expo_program
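Not part of the diff: a rough sketch of how the two helpers now split the compile pipeline, mirroring the updated compiler.py (Scale is a hypothetical module). Note the default-argument binding in the lambda above (filtered_pass=filtered_pass) pins each pass at list-build time, avoiding Python's late-binding closure pitfall.

# Sketch of the refactored flow: torch.ops-level passes run on the ExportedProgram,
# then the program is lowered to edge IR, then the Cadence exir-level passes run.
import torch
from executorch.backends.cadence.aot.compiler import _lower_ep_to_edge, trace
from executorch.backends.cadence.aot.passes import (
    apply_exir_ops_passes,
    apply_torch_ops_passes,
)

class Scale(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 3.0

expo_program = trace(Scale(), (torch.randn(4),))     # torch.export under the hood
expo_program = apply_torch_ops_passes(expo_program)  # e.g. ReplaceMulTensorWithMulAndFullOpsPass
edge_prog_manager = _lower_ep_to_edge(expo_program)  # lower to edge IR
cadence_prog_manager = apply_exir_ops_passes(1, edge_prog_manager)  # opt_level=1 Cadence passes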
