|  | 
|  | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 2 | +# All rights reserved. | 
|  | 3 | +# | 
|  | 4 | +# This source code is licensed under the BSD-style license found in the | 
|  | 5 | +# LICENSE file in the root directory of this source tree. | 
|  | 6 | + | 
|  | 7 | +import contextlib | 
|  | 8 | +import os | 
|  | 9 | +import typing | 
|  | 10 | +from abc import ABC, abstractmethod | 
|  | 11 | +from enum import Enum | 
|  | 12 | +from typing import Any, Dict, List, Optional, Set | 
|  | 13 | + | 
|  | 14 | +import torch | 
|  | 15 | +from executorch.backends.aoti.passes.replace_view_copy_with_view import ( | 
|  | 16 | +    ReplaceViewCopyWithViewPass, | 
|  | 17 | +) | 
|  | 18 | +from executorch.exir._serialize._named_data_store import NamedDataStore | 
|  | 19 | +from executorch.exir._warnings import experimental | 
|  | 20 | +from executorch.exir.backend.backend_details import ( | 
|  | 21 | +    BackendDetails, | 
|  | 22 | +    ExportedProgram, | 
|  | 23 | +    PreprocessResult, | 
|  | 24 | +) | 
|  | 25 | +from executorch.exir.backend.compile_spec_schema import CompileSpec | 
|  | 26 | +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu | 
|  | 27 | +from torch.export.passes import move_to_device_pass | 
|  | 28 | + | 
|  | 29 | + | 
class COMPILE_SPEC_KEYS(Enum):
    """Keys used in the ``CompileSpec`` entries exchanged with AOTI backends."""

    # Key of the compile spec whose value holds the method name.
    METHOD_NAME = "method_name"
|  | 32 | + | 
|  | 33 | + | 
@experimental(
    "This API and all of aoti-driven backend related functionality are experimental."
)
class AotiBackend(BackendDetails, ABC):
    """
    Base backend class for AOTInductor-based backends.

    This class provides common functionality for compiling models using AOTInductor
    with different device targets (CUDA, Metal/MPS, etc.).

    Subclasses implement the four abstract static hooks below. ``preprocess``
    uses them to move the program to the target device, compile it and return
    the generated shared object and weights blob through a ``NamedDataStore``.
    """

    @staticmethod
    @abstractmethod
    def get_device_name() -> str:
        """Return the device name for this backend (e.g., 'cuda', 'mps')."""
        pass

    @staticmethod
    @abstractmethod
    def get_supported_fallback_kernels() -> Dict[str, Any]:
        """
        Return the supported fallback kernels for this backend.

        The result is a mapping keyed by kernel name. This base class only
        tests kernel names for membership in it; the values are not read here.
        """
        pass

    @staticmethod
    @abstractmethod
    def get_decomposition_table() -> Dict[Any, Any]:
        """
        Return the decomposition table for this backend.

        An empty (falsy) table makes ``preprocess`` skip the decomposition step.
        """
        pass

    @staticmethod
    @abstractmethod
    def get_aoti_compile_options() -> Dict[str, typing.Any]:
        """Return the AOTInductor compilation options for this backend."""
        pass

    @classmethod
    @contextlib.contextmanager
    def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
        """
        Context manager to collect unsupported fallback kernels during compilation.

        While the context is active,
        ``CppWrapperCpu.generate_c_shim_extern_kernel_call`` and
        ``CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot`` are
        replaced by wrappers that record every kernel name that is not a key of
        ``get_supported_fallback_kernels()`` and then delegate to the original
        implementation. The originals are restored on exit, even if the body
        raises.

        NOTE: the replacement is done on the ``CppWrapperCpu`` class itself, so
        it is visible to every user of that class while the context is active.

        Args:
            missing_fallback_kernels: Set mutated in place; the names of the
                unsupported kernels encountered are added to it.
        """
        supported_kernels = cls.get_supported_fallback_kernels()

        # Keep references to the real implementations so that the wrappers can
        # delegate to them and so that they can be restored afterwards.
        original_generate_c_shim_extern_kernel_call = (
            CppWrapperCpu.generate_c_shim_extern_kernel_call
        )
        original_generate_fallback_kernel_with_runtime_lookup_aot = (
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot
        )

        def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
            self,
            kernel: str,
            args: list[str],
            device: str,
            *,
            debug_args: Optional[list[str]] = None,
            debug_handle: Optional[int] = None,
        ):
            if kernel not in supported_kernels:
                missing_fallback_kernels.add(kernel)

            # Propagate whatever the original returns so the wrapper stays
            # transparent to its callers.
            return original_generate_c_shim_extern_kernel_call(
                self,
                kernel,
                args,
                device,
                debug_args=debug_args,
                debug_handle=debug_handle,
            )

        def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
            self,
            op_overload,
            raw_args,
            output_args,
            raw_outputs,
        ):
            # Fall back to the string form when the overload has no ``_name``.
            kernel_name = getattr(op_overload, "_name", str(op_overload))
            if kernel_name not in supported_kernels:
                missing_fallback_kernels.add(kernel_name)

            return original_generate_fallback_kernel_with_runtime_lookup_aot(
                self, op_overload, raw_args, output_args, raw_outputs
            )

        try:
            CppWrapperCpu.generate_c_shim_extern_kernel_call = (
                generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
            )
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels
            yield
        finally:
            CppWrapperCpu.generate_c_shim_extern_kernel_call = (
                original_generate_c_shim_extern_kernel_call
            )
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
                original_generate_fallback_kernel_with_runtime_lookup_aot
            )

    @staticmethod
    def _find_artifact_paths(paths: Any) -> tuple[str, str]:
        """
        Pick the shared-object and weights-blob paths out of the compile result.

        Args:
            paths: Value returned by ``torch._inductor.aot_compile``.

        Returns:
            ``(so_path, blob_path)``: the entries ending in ``.wrapper.so`` and
            ``.wrapper_weights.blob``. If several entries match a suffix, the
            last one wins.

        Raises:
            RuntimeError: If either file is absent. A result that is not a list
                cannot provide a separate weights blob and is rejected too.
        """
        so_path = None
        blob_path = None

        if isinstance(paths, list):
            for path in paths:
                # Only string entries can be file paths; skip anything else
                # instead of failing on a missing ``endswith``.
                if not isinstance(path, str):
                    continue
                if path.endswith(".wrapper.so"):
                    so_path = path
                elif path.endswith(".wrapper_weights.blob"):
                    blob_path = path

        if so_path is None or blob_path is None:
            raise RuntimeError(
                f"Could not find required files in compiled paths, got {paths}"
            )
        return so_path, blob_path

    @staticmethod
    def _read_binary_file(path: str) -> bytes:
        """Return the full contents of ``path`` read in binary mode."""
        with open(path, "rb") as f:
            return f.read()

    @classmethod
    def preprocess(
        cls,
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Preprocess the edge program and compile it using AOTInductor.
        Weights are always separated from the SO file.

        Args:
            edge_program: Program to compile; it is moved to the device
                returned by ``get_device_name()`` before compilation.
            compile_specs: Compile specs; must contain the method-name spec
                (see ``generate_method_name_compile_spec``).

        Returns:
            A ``PreprocessResult`` with empty ``processed_bytes``. The shared
            object and the weights blob are returned through the named data
            store output under the keys ``<method_name>_so_blob`` and
            ``<method_name>_weights_blob``.

        Raises:
            RuntimeError: If the method name is missing from ``compile_specs``,
                if the compilation used fallback kernels that the backend does
                not support, or if the compiled output lacks the shared object
                or the weights blob.
        """
        # Resolve the method name once, before any compilation work, so that a
        # malformed compile spec is reported immediately.
        method_name = cls.method_name_from_compile_specs(compile_specs)

        device_name = cls.get_device_name()
        decomposition_table = cls.get_decomposition_table()
        options = cls.get_aoti_compile_options()

        # Move the edge_program to the target device
        device_edge_program = move_to_device_pass(edge_program, device_name)

        # Replace view_copy with view
        ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)

        # Run decompositions if any
        if decomposition_table:
            device_edge_program = device_edge_program.run_decompositions(
                decomposition_table
            )

        edge_program_module = device_edge_program.module()

        # Grab all input placeholders from the graph
        user_input_names = device_edge_program.graph_signature.user_inputs
        user_input_placeholders = [
            node.meta["val"]
            for node in device_edge_program.graph.nodes
            if node.op == "placeholder" and node.name in user_input_names
        ]

        # Track missing fallback kernels
        missing_fallback_kernels: Set[str] = set()

        # Compile with fallback kernel collection
        with cls.collect_unsupported_fallback_kernels(
            missing_fallback_kernels
        ), torch.no_grad():
            paths = torch._inductor.aot_compile(
                edge_program_module, tuple(user_input_placeholders), options=options
            )

        if missing_fallback_kernels:
            formatted_kernels = "\n  - ".join(sorted(missing_fallback_kernels))
            raise RuntimeError(
                f"Method {method_name} missing fallback kernels ({len(missing_fallback_kernels)} total):\n  - {formatted_kernels}\n"
                "Please add them to the AOTI backend."
            )

        # Extract paths - weights are always separated
        so_path, blob_path = cls._find_artifact_paths(paths)

        named_data_store = NamedDataStore()
        try:
            so_data = cls._read_binary_file(so_path)
            blob_data = cls._read_binary_file(blob_path)

            # Add SO and weights blob separately
            named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
            weights_blob_data_type = f"aoti_{device_name}_blob"
            named_data_store.add_named_data(
                method_name + "_weights_blob", blob_data, 1, weights_blob_data_type
            )
        finally:
            # Clean up the generated files even when reading or storing them
            # failed, so that no artifacts are left behind on disk.
            for generated_file in (so_path, blob_path):
                with contextlib.suppress(FileNotFoundError):
                    os.remove(generated_file)

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )

    @staticmethod
    def generate_method_name_compile_spec(
        method_name: str,
    ) -> CompileSpec:
        """
        Generate a CompileSpec for the given method name.

        The spec uses the ``COMPILE_SPEC_KEYS.METHOD_NAME`` key and stores the
        name UTF-8 encoded; ``method_name_from_compile_specs`` reverses this.
        """
        return CompileSpec(
            COMPILE_SPEC_KEYS.METHOD_NAME.value,
            method_name.encode("utf-8"),
        )

    @staticmethod
    def method_name_from_compile_specs(
        compile_specs: List[CompileSpec],
    ) -> str:
        """
        Extract the method name from the compile specs.

        Returns:
            The UTF-8 decoded value of the first spec whose key is
            ``COMPILE_SPEC_KEYS.METHOD_NAME``.

        Raises:
            RuntimeError: If no spec carries that key.
        """
        for spec in compile_specs:
            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
                return spec.value.decode("utf-8")
        raise RuntimeError(
            f"Could not find method name in compile specs: {compile_specs}"
        )
0 commit comments