diff --git a/exir/program/_program.py b/exir/program/_program.py index 6b72d190f9d..144cd0d0e8e 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -9,12 +9,13 @@ import copy import io import logging -from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Union +from typing import Any, Dict, List, Optional, Sequence, Set, TextIO, Tuple, Union import torch import torch._export from executorch.exir._serialize import _serialize_pte_binary from executorch.exir._serialize._cord import Cord +from executorch.exir._warnings import experimental from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.partitioner import Partitioner from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig @@ -1057,6 +1058,54 @@ def to_edge_transform_and_lower( return edge_manager +@experimental( + """ + This is an experimental API which overloads to_edge by preserving specified ops to not be decomposed. + This function will be combined with to_edge in the future. + """ +) +def to_edge_with_preserved_ops( + programs: Union[ExportedProgram, Dict[str, ExportedProgram]], + constant_methods: Optional[Dict[str, Any]] = None, + compile_config: Optional[EdgeCompileConfig] = None, + preserve_ops: Tuple[torch._ops.OpOverload, ...] = (), +) -> "EdgeProgramManager": + """ + :func:`to_edge_with_preserved_ops` constructs an EdgeProgramManager from a set of exported programs in + ATen dialect. Upon construction those programs are transformed into edge dialect, with the ops given in preserve_ops excluded from decomposition. + + Args: + programs: Can be a single ExportedProgram or a dictionary mapping function names to their corresponding ExportedPrograms. If only a single ExportedProgram is provided it will be assigned the name "forward". + constant_methods: An optional dictionary of method name to the constant value returned by that method in eager mode. Often used to store config information on Edge models. 
+ compile_config: An optional argument used to provide greater control over the transformation to edge dialect process. + preserve_ops: A tuple of op overloads that should not be decomposed. It is passed to run_decompositions for every program and forwarded (as a list) to the edge program generation step and to the returned EdgeProgramManager. Defaults to an empty tuple. + + Returns: + EdgeProgramManager + """ + assert not isinstance(constant_methods, EdgeCompileConfig) + config = compile_config or EdgeCompileConfig() + if not isinstance(programs, dict): + aten_programs = {"forward": programs} + else: + aten_programs = programs + + edge_programs: Dict[str, ExportedProgram] = {} + + for name, program in aten_programs.items(): + # Decompose to Core ATen + program = program.run_decompositions( + _default_decomposition_table(), _preserve_ops=preserve_ops + ) + edge_programs[name] = _generate_edge_program( + name, config, program, list(preserve_ops) + ) + + return EdgeProgramManager( + edge_programs, constant_methods, config, list(preserve_ops) + ) + + def to_edge( programs: Union[ExportedProgram, Dict[str, ExportedProgram]], constant_methods: Optional[Dict[str, Any]] = None, diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 73f023e778b..edad4b24f1c 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -26,6 +26,7 @@ ExecutorchProgramManager, to_edge, to_edge_transform_and_lower, + to_edge_with_preserved_ops, ) from executorch.exir.tracer import _default_decomposition_table from executorch.exir.verification.verifier import EXIREdgeDialectVerifier @@ -716,3 +717,89 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) except SpecViolationError: self.fail("Should not error out on linalg_vector_norm op") + + def _test_to_edge_with_preserved_ops( + self, program, preserved_ops, expected_preserved_ops + ): + edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops) + + def count_nodes(graph_module, target): + count = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target in target: + count += 1 + return count + + 
aten_ops_non_decomposed = count_nodes( + program.graph_module, + preserved_ops, + ) + + edge_ops_non_decomposed = count_nodes( + edge.exported_program().graph_module, + expected_preserved_ops, + ) + + self.assertEqual(aten_ops_non_decomposed, edge_ops_non_decomposed) + + def test_to_edge_with_single_preserved_op(self): + model = TestLinear() + program = torch.export.export(model, model._get_random_inputs()) + + ops_not_to_decompose = [ + torch.ops.aten.linear.default, + ] + expected_non_decomposed_edge_ops = [ + exir_ops.edge.aten.linear.default, + ] + + self._test_to_edge_with_preserved_ops( + program, ops_not_to_decompose, expected_non_decomposed_edge_ops + ) + + def test_to_edge_with_partial_ops_preserved(self): + model = TestLinearSDPACombined() + program = torch.export.export(model, model._get_random_inputs()) + + ops_not_to_decompose = [ + torch.ops.aten.linear.default, + ] + expected_non_decomposed_edge_ops = [ + exir_ops.edge.aten.linear.default, + ] + + self._test_to_edge_with_preserved_ops( + program, ops_not_to_decompose, expected_non_decomposed_edge_ops + ) + + def test_to_edge_with_multiple_ops_preserved(self): + model = TestLinearSDPACombined() + program = torch.export.export(model, model._get_random_inputs()) + + ops_not_to_decompose = [ + torch.ops.aten.linear.default, + torch.ops.aten.scaled_dot_product_attention.default, + ] + expected_non_decomposed_edge_ops = [ + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten.scaled_dot_product_attention.default, + ] + + self._test_to_edge_with_preserved_ops( + program, ops_not_to_decompose, expected_non_decomposed_edge_ops + ) + + def test_to_edge_with_preserved_ops_not_in_model(self): + model = TestSDPA() + program = torch.export.export(model, model._get_random_inputs()) + + ops_not_to_decompose = [ + torch.ops.aten.linear.default, + ] + expected_non_decomposed_edge_ops = [ + exir_ops.edge.aten.linear.default, + ] + + self._test_to_edge_with_preserved_ops( + program, ops_not_to_decompose, 
expected_non_decomposed_edge_ops + )