# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, List, Optional, Tuple

import torch
from executorch.exir._warnings import experimental
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.export.exported_program import ExportedProgram


@experimental(
    "This API and all CUDA-backend-related functionality are experimental."
)
class AotiPartitioner(Partitioner):
    """
    Base partitioner for AOTInductor-driven backend integration.

    This partitioner creates a single partition containing all operators from the
    input graph. It skips core ATen decomposition, allowing the backend to handle
    decomposition using AOTInductor's backend-specific decomposition table.

    Only operators that the AOTInductor library cannot handle are excluded from
    the partition; they fall back to ExecuTorch's default or custom handling.
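
    Example (hypothetical usage; the backend name, the empty compile-spec list,
    and the ``to_backend`` lowering entry point are illustrative assumptions,
    not a confirmed API of this class):

        from executorch.exir.backend.backend_api import to_backend

        # exported_program: an ExportedProgram, e.g. from torch.export.export()
        partitioner = AotiPartitioner("CudaBackend", [])
        lowered = to_backend(exported_program, partitioner)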
    """

    def __init__(self, backend_name: str, compile_spec: List[CompileSpec]) -> None:
        """
        Initialize the AOTI partitioner.

        Args:
            backend_name: The name of the backend (e.g., "CudaBackend", "MetalBackend")
            compile_spec: List of compilation specifications
        """
        self.delegation_spec = DelegationSpec(backend_name, compile_spec)

    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        """
        Fully delegate the graph to AOTInductor by tagging all nodes as a single
        partition.
        """

        partition_tags: Dict[str, DelegationSpec] = {}
        tag = "tag0"

        # Tag every call_function node so the whole graph forms one partition.
        for node in exported_program.graph.nodes:
            if node.op != "call_function":
                continue
            node.meta["delegation_tag"] = tag

        partition_tags[tag] = self.delegation_spec

        tag_constant_data(exported_program)
        tag_mutated_buffer(exported_program)

        # Tag any remaining constant placeholders. tag_constant_data only tags
        # constants whose users carry a delegation_tag, but this whole-graph
        # partition needs every constant tagged.
        for node in exported_program.graph.nodes:
            if node.op == "placeholder" and (
                is_param(exported_program, node)
                or is_buffer(exported_program, node)
                or is_lifted_tensor_constant(exported_program, node)
            ):
                if "delegation_tag" not in node.meta:
                    node.meta["delegation_tag"] = tag

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )

    def ops_to_not_decompose(
        self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        """
        Return the list of operations that should not be decomposed, leaving them
        for the AOT compiler to handle. Currently we skip ATen decomposition for
        all ops in the graph and let the backend handle them.
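
        For example, if the graph contains calls to ``aten.add.Tensor`` and
        ``aten.mm.default``, this returns
        ``([torch.ops.aten.add.Tensor, torch.ops.aten.mm.default], None)``; the
        ``None`` in place of a filter callable means the exclusion applies to
        every matching node. (The ops named here are illustrative; the actual
        list depends on the graph.)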
        """
        do_not_decompose = set()

        for node in ep.graph.nodes:
            if node.op == "call_function" and isinstance(
                node.target, torch._ops.OpOverload
            ):
                do_not_decompose.add(node.target)
        return list(do_not_decompose), None