Skip to content

Commit 9301ebb

Browse files
David Linfacebook-github-bot
authored andcommitted
Add experimental API for preserving ops from decomposition (#5236)
Summary: Pull Request resolved: #5236 This experimental API passes a tuple of ops to `program.run_decomposition` which prevents the list of provided ops from being decomposed. Once this API is no longer experimental, the new param will be moved to `to_edge` and this method will be removed. Reviewed By: tarun292 Differential Revision: D62460623 fbshipit-source-id: b4ce1a1a5bcf4064cc0e4ff251ce5a56cd0fad7d
1 parent ca2ac54 commit 9301ebb

File tree

2 files changed

+137
-1
lines changed

2 files changed

+137
-1
lines changed

exir/program/_program.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
import copy
1010
import io
1111
import logging
12-
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Union
12+
from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union
1313

1414
import torch
1515
import torch._export
1616
from executorch.exir._serialize import _serialize_pte_binary
1717
from executorch.exir._serialize._cord import Cord
18+
from executorch.exir._warnings import experimental
1819
from executorch.exir.backend.backend_api import to_backend
1920
from executorch.exir.backend.partitioner import Partitioner
2021
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
@@ -1057,6 +1058,54 @@ def to_edge_transform_and_lower(
10571058
return edge_manager
10581059

10591060

1061+
@experimental(
1062+
"""
1063+
This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed.
1064+
This function will be combined with to_edge in the future.
1065+
"""
1066+
)
1067+
def to_edge_with_preserved_ops(
1068+
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
1069+
constant_methods: Optional[Dict[str, Any]] = None,
1070+
compile_config: Optional[EdgeCompileConfig] = None,
1071+
preserve_ops: Tuple[torch._ops.OpOverload, ...] = (),
1072+
) -> "EdgeProgramManager":
1073+
"""
1074+
:func:`to_edge` constructs an EdgeProgramManager from a set of exported programs in
1075+
ATen dialect. Upon construction those programs are transformed into edge dialect.
1076+
1077+
Args:
1078+
programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward".
1079+
constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models.
1080+
compile_config: An optional argument used to provide greater control over the transformation to edge dialect process.
1081+
preserve_ops: An argument used to specify ops that should not be decomposed.
1082+
1083+
Returns:
1084+
EdgeProgramManager
1085+
"""
1086+
assert not isinstance(constant_methods, EdgeCompileConfig)
1087+
config = compile_config or EdgeCompileConfig()
1088+
if not isinstance(programs, dict):
1089+
aten_programs = {"forward": programs}
1090+
else:
1091+
aten_programs = programs
1092+
1093+
edge_programs: Dict[str, ExportedProgram] = {}
1094+
1095+
for name, program in aten_programs.items():
1096+
# Decompose to Core ATen
1097+
program = program.run_decompositions(
1098+
_default_decomposition_table(), _preserve_ops=preserve_ops
1099+
)
1100+
edge_programs[name] = _generate_edge_program(
1101+
name, config, program, list(preserve_ops)
1102+
)
1103+
1104+
return EdgeProgramManager(
1105+
edge_programs, constant_methods, config, list(preserve_ops)
1106+
)
1107+
1108+
10601109
def to_edge(
10611110
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
10621111
constant_methods: Optional[Dict[str, Any]] = None,

exir/program/test/test_program.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
ExecutorchProgramManager,
2727
to_edge,
2828
to_edge_transform_and_lower,
29+
to_edge_with_preserved_ops,
2930
)
3031
from executorch.exir.tracer import _default_decomposition_table
3132
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
@@ -716,3 +717,89 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
716717
)
717718
except SpecViolationError:
718719
self.fail("Should not error out on linalg_vector_norm op")
720+
721+
def _test_to_edge_with_preserved_ops(
722+
self, program, preserved_ops, expected_preserved_ops
723+
):
724+
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)
725+
726+
def count_nodes(graph_module, target):
727+
count = 0
728+
for node in graph_module.graph.nodes:
729+
if node.op == "call_function" and node.target in target:
730+
count += 1
731+
return count
732+
733+
aten_ops_non_decomposed = count_nodes(
734+
program.graph_module,
735+
preserved_ops,
736+
)
737+
738+
edge_ops_non_decomposed = count_nodes(
739+
edge.exported_program().graph_module,
740+
expected_preserved_ops,
741+
)
742+
743+
self.assertEqual(aten_ops_non_decomposed, edge_ops_non_decomposed)
744+
745+
def test_to_edge_with_single_preserved_op(self):
746+
model = TestLinear()
747+
program = torch.export.export(model, model._get_random_inputs())
748+
749+
ops_not_to_decompose = [
750+
torch.ops.aten.linear.default,
751+
]
752+
expected_non_decomposed_edge_ops = [
753+
exir_ops.edge.aten.linear.default,
754+
]
755+
756+
self._test_to_edge_with_preserved_ops(
757+
program, ops_not_to_decompose, expected_non_decomposed_edge_ops
758+
)
759+
760+
def test_to_edge_with_partial_ops_preserved(self):
761+
model = TestLinearSDPACombined()
762+
program = torch.export.export(model, model._get_random_inputs())
763+
764+
ops_not_to_decompose = [
765+
torch.ops.aten.linear.default,
766+
]
767+
expected_non_decomposed_edge_ops = [
768+
exir_ops.edge.aten.linear.default,
769+
]
770+
771+
self._test_to_edge_with_preserved_ops(
772+
program, ops_not_to_decompose, expected_non_decomposed_edge_ops
773+
)
774+
775+
def test_to_edge_with_multiple_ops_preserved(self):
776+
model = TestLinearSDPACombined()
777+
program = torch.export.export(model, model._get_random_inputs())
778+
779+
ops_not_to_decompose = [
780+
torch.ops.aten.linear.default,
781+
torch.ops.aten.scaled_dot_product_attention.default,
782+
]
783+
expected_non_decomposed_edge_ops = [
784+
exir_ops.edge.aten.linear.default,
785+
exir_ops.edge.aten.scaled_dot_product_attention.default,
786+
]
787+
788+
self._test_to_edge_with_preserved_ops(
789+
program, ops_not_to_decompose, expected_non_decomposed_edge_ops
790+
)
791+
792+
def test_to_edge_with_preserved_ops_not_in_model(self):
793+
model = TestSDPA()
794+
program = torch.export.export(model, model._get_random_inputs())
795+
796+
ops_not_to_decompose = [
797+
torch.ops.aten.linear.default,
798+
]
799+
expected_non_decomposed_edge_ops = [
800+
exir_ops.edge.aten.linear.default,
801+
]
802+
803+
self._test_to_edge_with_preserved_ops(
804+
program, ops_not_to_decompose, expected_non_decomposed_edge_ops
805+
)

0 commit comments

Comments
 (0)