40 changes: 12 additions & 28 deletions backends/cadence/aot/compiler.py
@@ -8,7 +8,7 @@

import logging
from pathlib import Path
from typing import Callable, cast, Optional
from typing import Optional

import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
@@ -32,7 +32,6 @@
ExecutorchBackendConfig,
ExecutorchProgramManager,
)
from executorch.exir.pass_base import PassResult
from executorch.exir.passes import ToOutVarPass
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
from executorch.exir.program._program import to_edge_with_preserved_ops
@@ -41,7 +40,7 @@
from torch.export.exported_program import ExportedProgram
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from .passes import get_cadence_passes
from .passes import apply_exir_ops_passes, apply_torch_ops_passes

from .utils import print_ops_info

@@ -256,14 +255,20 @@ def export_to_edge(
inputs: tuple[object, ...],
dump_graphs: bool = False,
constant_methods: Optional[dict[str, object]] = None,
core_aten_exceptions: Optional[list[torch._ops.OpOverload]] = None,
) -> EdgeProgramManager:
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

# Export the model into an ExportedProgram.
expo_program = trace(model, inputs)

# Apply passes which transform the ExportedProgram before it gets lowered to edge.
expo_program = apply_torch_ops_passes(expo_program)

# Lower the model to edge IR.
edge_prog_manager = _lower_ep_to_edge(expo_program, dump_graphs, constant_methods)
edge_prog_manager = _lower_ep_to_edge(
expo_program, dump_graphs, constant_methods, core_aten_exceptions
)

return edge_prog_manager

@@ -305,14 +310,7 @@ def _lower_ep_to_cadence(
Lower an existing ExportedProgram to edge IR and apply frontend optimization passes.
"""
edge_prog_manager = _lower_ep_to_edge(program, dump_graphs=dump_graphs)
cadence_passes = get_cadence_passes(opt_level)

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
cast(
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
)
)
cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
return cadence_prog_manager


@@ -323,14 +321,7 @@ def export_to_cadence(
opt_level: int = 1,
) -> EdgeProgramManager:
edge_prog_manager = export_to_edge(model, inputs, dump_graphs=dump_graphs)
cadence_passes = get_cadence_passes(opt_level)

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
cast(
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
)
)
cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)
return cadence_prog_manager


@@ -367,15 +358,8 @@ def export_to_executorch_gen_etrecord(
memory_config: Optional[MemoryConfig] = None,
dump_graphs: bool = False,
) -> ExecutorchProgramManager:
cadence_passes = get_cadence_passes(opt_level)
edge_prog_manager = export_to_edge(model, inputs, dump_graphs)

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
cast(
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
)
)
cadence_prog_manager = apply_exir_ops_passes(opt_level, edge_prog_manager)

# Print some information to terminal
print_ops_info(
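With this change, export_to_edge applies the torch.ops-level passes to the ExportedProgram right after tracing and forwards the new core_aten_exceptions argument to _lower_ep_to_edge. Below is a minimal, hedged sketch of how a caller might exercise the new signature; the module path and the toy model are assumptions, not taken from this PR.

```python
# Sketch only: assumes export_to_edge is importable from the module this diff
# touches (executorch.backends.cadence.aot.compiler); the model is a toy stand-in.
import torch

from executorch.backends.cadence.aot.compiler import export_to_edge


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


model = TinyModel()
inputs = (torch.randn(4),)

# export_to_edge now traces the model, runs apply_torch_ops_passes on the
# resulting ExportedProgram, and only then lowers to edge IR, passing
# core_aten_exceptions through to _lower_ep_to_edge.
edge_prog_manager = export_to_edge(
    model,
    inputs,
    dump_graphs=False,
    core_aten_exceptions=None,  # illustrative; ops to exempt from core-ATen decomposition
)
```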
1 change: 1 addition & 0 deletions backends/cadence/aot/fuse_ops.py
@@ -1127,6 +1127,7 @@ class CadenceFuseOpsInGraph:
FuseCascadedTransposeOrPermuteOps,
FuseCascadedViewOps,
FuseQuantDequantToRequantizePass,
FuseMulTensorIntoQuantPass,
FuseMulTensorIntoDequantPass,
FuseMulScalarIntoDequantPass,
FuseFullThenReshapePass,
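FuseMulTensorIntoQuantPass is now registered in CadenceFuseOpsInGraph alongside the existing dequant-side mul fusions. The diff does not spell out its mechanics here, but a fusion of this shape typically rests on the identity that scaling a tensor before quantization is equivalent to shrinking the quantization scale; the snippet below only checks that identity and is not code from the pass.

```python
# Standalone check of the identity a mul -> quantize fusion can rely on
# (an assumption about the pass, not taken from this diff):
#   quantize(x * c, scale, zp) == quantize(x, scale / c, zp)
import torch

x = torch.randn(8)
c, scale, zero_point = 3.0, 0.05, 0

q_scaled_input = torch.quantize_per_tensor(x * c, scale, zero_point, torch.qint8)
q_scaled_scale = torch.quantize_per_tensor(x, scale / c, zero_point, torch.qint8)

# Allow an off-by-one from floating-point rounding at round() boundaries.
diff = q_scaled_input.int_repr().to(torch.int32) - q_scaled_scale.int_repr().to(torch.int32)
assert diff.abs().max() <= 1
```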
44 changes: 36 additions & 8 deletions backends/cadence/aot/passes.py
@@ -6,7 +6,7 @@

# pyre-strict

from typing import Any, List, Optional
from typing import Any, Callable, cast, List, Optional

import torch
import torch.fx
@@ -28,13 +28,18 @@
RemoveRedundantOps,
)
from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
from executorch.backends.cadence.aot.replace_ops import CadenceReplaceOpsInGraph
from executorch.backends.cadence.aot.replace_ops import (
CadenceReplaceOpsInGraph,
ReplaceMulTensorWithMulAndFullOpsPass,
)
from executorch.backends.cadence.aot.simplify_ops import CadenceSimplifyOpsInGraph
from executorch.exir import EdgeProgramManager
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch.export.exported_program import ExportedProgram


@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -89,14 +94,37 @@ def get_passes_in_default_order() -> List[ExportPass]:
return pytree.tree_flatten(passes)[0]


def get_cadence_passes(
def apply_exir_ops_passes(
opt_level: int,
) -> List[Optional[PassResult]]:
edge_prog_manager: EdgeProgramManager,
) -> EdgeProgramManager:
passes = get_passes_in_default_order()
pass_filter = create_cadence_pass_filter(opt_level)
filtered_passes = [
# pyre-ignore[20]: Expect argument graph_module
filtered_pass()
cadence_passes = [
(
lambda graph_module, filtered_pass=filtered_pass: filtered_pass()(
graph_module
)
)
for filtered_pass in list(filter(pass_filter, passes))
]
return filtered_passes
cadence_prog_manager = edge_prog_manager.transform(
cast(
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
)
)
return cadence_prog_manager


def apply_torch_ops_passes(expo_program: ExportedProgram) -> ExportedProgram:
"""
Applies compiler passes on torch.ops IR, including torch.ops.aten, torch.ops.cadence, etc.
expo_program is expected to be the output of the torch.export.export().
"""

aten_passes: List[Callable[[torch.fx.GraphModule], Optional[PassResult]]] = [
ReplaceMulTensorWithMulAndFullOpsPass()
]
# TODO(T230417247): Use PassResult which is currently ignored.
PassManager(aten_passes)(expo_program.graph_module)
return expo_program
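Taken together, the refactor splits the old get_cadence_passes + transform(...) boilerplate into two entry points: apply_torch_ops_passes, which runs ATen-level passes (currently ReplaceMulTensorWithMulAndFullOpsPass) on the exported program via PassManager, and apply_exir_ops_passes, which filters the Cadence passes by opt_level and applies them through EdgeProgramManager.transform. Below is a hedged sketch of how the two compose, mirroring the updated call sites in compiler.py; the toy model and the commented lowering step are assumptions.

```python
# Sketch only: mirrors the updated flow in compiler.py, assuming the new helpers
# are importable from executorch.backends.cadence.aot.passes.
import torch

from executorch.backends.cadence.aot.passes import (
    apply_exir_ops_passes,
    apply_torch_ops_passes,
)


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


ep = torch.export.export(TinyModel(), (torch.randn(4),))

# Torch.ops-level passes run on the ExportedProgram's graph module; the
# PassResult is currently ignored (see the TODO in the diff).
ep = apply_torch_ops_passes(ep)

# After lowering to edge IR (e.g. via _lower_ep_to_edge in compiler.py), the
# EXIR-level Cadence passes are applied in one call instead of the old
# get_cadence_passes + edge_prog_manager.transform(...) pattern:
# edge_prog_manager = apply_exir_ops_passes(opt_level=1, edge_prog_manager=edge_prog_manager)
```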