Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion exir/program/_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import copy
import io
import logging
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Union
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union

import torch
import torch._export
from executorch.exir._serialize import _serialize_pte_binary
from executorch.exir._serialize._cord import Cord
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.partitioner import Partitioner
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
Expand Down Expand Up @@ -1057,6 +1058,54 @@ def to_edge_transform_and_lower(
return edge_manager


@experimental(
"""
This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed.
This function will be combined with to_edge in the future.
"""
)
def to_edge_with_preserved_ops(
    programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
    constant_methods: Optional[Dict[str, Any]] = None,
    compile_config: Optional[EdgeCompileConfig] = None,
    preserve_ops: Sequence[torch._ops.OpOverload] = (),
) -> "EdgeProgramManager":
    """
    Constructs an EdgeProgramManager from a set of exported programs in ATen
    dialect, keeping the ops listed in ``preserve_ops`` from being decomposed.

    Each program is first run through ``run_decompositions`` with the default
    decomposition table (with ``preserve_ops`` excluded from decomposition) and
    is then converted with ``_generate_edge_program``.

    Args:
        programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward".
        constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models.
        compile_config: An optional argument used to provide greater control over the transformation to edge dialect process. A default ``EdgeCompileConfig()`` is used when omitted.
        preserve_ops: The ops that should not be decomposed. Any iterable of op overloads is accepted; it is materialized once so the same ops are applied to every program and to the returned manager.

    Returns:
        EdgeProgramManager

    Raises:
        AssertionError: If ``constant_methods`` is an ``EdgeCompileConfig``.
    """
    # NOTE(review): presumably guards against a config being passed
    # positionally in the ``constant_methods`` slot -- confirm with callers.
    assert not isinstance(constant_methods, EdgeCompileConfig)
    config = compile_config or EdgeCompileConfig()

    # Materialize once: ``preserve_ops`` is read several times below, and a
    # one-shot iterable would otherwise be empty after its first use.
    preserved: Tuple[torch._ops.OpOverload, ...] = tuple(preserve_ops)

    if not isinstance(programs, dict):
        aten_programs = {"forward": programs}
    else:
        aten_programs = programs

    edge_programs: Dict[str, ExportedProgram] = {}

    for name, program in aten_programs.items():
        # Decompose to Core ATen, leaving the preserved ops intact.
        program = program.run_decompositions(
            _default_decomposition_table(), _preserve_ops=preserved
        )
        # A fresh list per call, as before, so callees never share one list.
        edge_programs[name] = _generate_edge_program(
            name, config, program, list(preserved)
        )

    return EdgeProgramManager(
        edge_programs, constant_methods, config, list(preserved)
    )


def to_edge(
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
constant_methods: Optional[Dict[str, Any]] = None,
Expand Down
87 changes: 87 additions & 0 deletions exir/program/test/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ExecutorchProgramManager,
to_edge,
to_edge_transform_and_lower,
to_edge_with_preserved_ops,
)
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
Expand Down Expand Up @@ -716,3 +717,89 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
)
except SpecViolationError:
self.fail("Should not error out on linalg_vector_norm op")

def _test_to_edge_with_preserved_ops(
    self, program, preserved_ops, expected_preserved_ops
):
    """
    Lower ``program`` with ``to_edge_with_preserved_ops`` and check that the
    number of ``call_function`` nodes targeting ``preserved_ops`` in the input
    graph equals the number targeting ``expected_preserved_ops`` in the
    resulting edge graph.
    """

    def count_nodes(graph_module, target):
        # Number of call_function nodes whose target is one of ``target``.
        return sum(
            1
            for node in graph_module.graph.nodes
            if node.op == "call_function" and node.target in target
        )

    edge_manager = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)

    # Occurrences of the to-be-preserved ops in the original ATen graph.
    num_before = count_nodes(program.graph_module, preserved_ops)

    # Occurrences of their edge-dialect counterparts after lowering.
    edge_graph_module = edge_manager.exported_program().graph_module
    num_after = count_nodes(edge_graph_module, expected_preserved_ops)

    self.assertEqual(num_before, num_after)

def test_to_edge_with_single_preserved_op(self):
    """Preserve ``aten.linear`` when lowering an exported ``TestLinear``."""
    module = TestLinear()
    exported = torch.export.export(module, module._get_random_inputs())

    aten_ops_to_keep = [torch.ops.aten.linear.default]
    edge_ops_expected = [exir_ops.edge.aten.linear.default]

    self._test_to_edge_with_preserved_ops(
        exported, aten_ops_to_keep, edge_ops_expected
    )

def test_to_edge_with_partial_ops_preserved(self):
    """
    Preserve only ``aten.linear`` when lowering an exported
    ``TestLinearSDPACombined``.
    """
    # NOTE(review): going by its name the model also contains SDPA, which is
    # deliberately not listed here -- confirm against the model definition.
    module = TestLinearSDPACombined()
    exported = torch.export.export(module, module._get_random_inputs())

    aten_ops_to_keep = [torch.ops.aten.linear.default]
    edge_ops_expected = [exir_ops.edge.aten.linear.default]

    self._test_to_edge_with_preserved_ops(
        exported, aten_ops_to_keep, edge_ops_expected
    )

def test_to_edge_with_multiple_ops_preserved(self):
    """
    Preserve both ``aten.linear`` and ``aten.scaled_dot_product_attention``
    when lowering an exported ``TestLinearSDPACombined``.
    """
    module = TestLinearSDPACombined()
    exported = torch.export.export(module, module._get_random_inputs())

    aten_ops_to_keep = [
        torch.ops.aten.linear.default,
        torch.ops.aten.scaled_dot_product_attention.default,
    ]
    edge_ops_expected = [
        exir_ops.edge.aten.linear.default,
        exir_ops.edge.aten.scaled_dot_product_attention.default,
    ]

    self._test_to_edge_with_preserved_ops(
        exported, aten_ops_to_keep, edge_ops_expected
    )

def test_to_edge_with_preserved_ops_not_in_model(self):
    """
    Request preservation of ``aten.linear`` when lowering an exported
    ``TestSDPA``.
    """
    # NOTE(review): presumably TestSDPA has no linear op, so both counts are
    # expected to be zero -- confirm against the model definition.
    module = TestSDPA()
    exported = torch.export.export(module, module._get_random_inputs())

    aten_ops_to_keep = [torch.ops.aten.linear.default]
    edge_ops_expected = [exir_ops.edge.aten.linear.default]

    self._test_to_edge_with_preserved_ops(
        exported, aten_ops_to_keep, edge_ops_expected
    )
Loading