|  | 
|  | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. | 
|  | 2 | +# All rights reserved. | 
|  | 3 | +# | 
|  | 4 | +# This source code is licensed under the BSD-style license found in the | 
|  | 5 | +# LICENSE file in the root directory of this source tree. | 
|  | 6 | + | 
|  | 7 | +import contextlib | 
|  | 8 | +import os | 
|  | 9 | +import typing | 
|  | 10 | + | 
|  | 11 | +from typing import Any, Dict, final, List, Optional, Set | 
|  | 12 | + | 
|  | 13 | +import torch | 
|  | 14 | +from executorch.exir._serialize._named_data_store import NamedDataStore | 
|  | 15 | +from executorch.exir.backend.backend_details import ( | 
|  | 16 | +    BackendDetails, | 
|  | 17 | +    ExportedProgram, | 
|  | 18 | +    PreprocessResult, | 
|  | 19 | +) | 
|  | 20 | +from executorch.exir.backend.compile_spec_schema import CompileSpec | 
|  | 21 | +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu | 
|  | 22 | +from torch.export.passes import move_to_device_pass | 
|  | 23 | + | 
|  | 24 | + | 
# Fallback kernels that already have an implementation in the et namespace,
# looked up by kernel name.
supported_fallback_kernels: Dict[str, Any] = {}

# Names of fallback kernels that were required during compilation but are not
# present in supported_fallback_kernels.
missing_fallback_kernels: Set[str] = set()
|  | 30 | + | 
|  | 31 | + | 
@contextlib.contextmanager
def collect_unsupported_fallback_kernels():
    """Record fallback kernels requested during AOTI codegen that are unsupported.

    While the context is active, two ``CppWrapperCpu`` code-generation methods
    (``generate_c_shim_extern_kernel_call`` and
    ``generate_fallback_kernel_with_runtime_lookup_aot``) are replaced by
    wrappers. Each wrapper adds the kernel name to ``missing_fallback_kernels``
    when it is not found in ``supported_fallback_kernels`` and then delegates
    to the original method. The original methods are restored on exit, even
    if the body raises.

    This context manager only collects names; it never raises on its own. The
    caller is expected to inspect ``missing_fallback_kernels`` afterwards.
    ``missing_fallback_kernels`` is cleared on entry so that names recorded by
    an earlier compilation do not leak into this one.
    """
    original_generate_c_shim_extern_kernel_call = (
        CppWrapperCpu.generate_c_shim_extern_kernel_call
    )
    original_generate_fallback_kernel_with_runtime_lookup_aot = (
        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot
    )

    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
        self,
        kernel: str,
        *call_args,
        **call_kwargs,
    ):
        # Only the kernel name is inspected here; every other argument is
        # forwarded untouched so the wrapper does not depend on the exact
        # signature of the wrapped method.
        if kernel not in supported_fallback_kernels:
            missing_fallback_kernels.add(kernel)

        return original_generate_c_shim_extern_kernel_call(
            self, kernel, *call_args, **call_kwargs
        )

    def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
        self,
        op_overload,
        *call_args,
        **call_kwargs,
    ):
        # Prefer the overload's ``_name`` attribute; fall back to its string
        # form when that attribute is absent.
        kernel_name = getattr(op_overload, "_name", str(op_overload))
        if kernel_name not in supported_fallback_kernels:
            missing_fallback_kernels.add(kernel_name)

        return original_generate_fallback_kernel_with_runtime_lookup_aot(
            self, op_overload, *call_args, **call_kwargs
        )

    # Start every collection from a clean slate: the set is module-level state
    # shared by all uses of this context manager.
    missing_fallback_kernels.clear()

    CppWrapperCpu.generate_c_shim_extern_kernel_call = (
        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
    )
    CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
        generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels
    )
    try:
        yield
    finally:
        # Always undo the monkey-patching, including on error.
        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
            original_generate_c_shim_extern_kernel_call
        )
        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
            original_generate_fallback_kernel_with_runtime_lookup_aot
        )
|  | 89 | + | 
|  | 90 | + | 
@final
class CudaBackend(BackendDetails):
    """
    CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate
    optimized CUDA kernels for the model's operators without depending on libtorch. The compiled model can be executed
    on CUDA devices using the ExecuTorch runtime.
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """Compile ``edge_program`` with AOTInductor and package the resulting shared object.

        Args:
            edge_program: The exported program to compile. It is moved to CUDA
                before compilation.
            compile_specs: Backend compile specs. Currently not consulted.

        Returns:
            A ``PreprocessResult`` whose ``processed_bytes`` is empty; the
            compiled shared object is stored in the named data store under
            the key ``"so_blob"``.

        Raises:
            RuntimeError: If compilation required fallback kernels that are
                not listed in ``supported_fallback_kernels``.
        """
        # Move the edge_program from CPU to CUDA for aoti compile
        cuda_edge_program = move_to_device_pass(edge_program, "cuda")

        edge_program_module = cuda_edge_program.module()

        # Build example inputs from the metadata of the user-input placeholders.
        # Non-tensor placeholder values are skipped.
        user_input_names = cuda_edge_program.graph_signature.user_inputs
        example_inputs = []
        for node in cuda_edge_program.graph.nodes:
            if node.op != "placeholder" or node.name not in user_input_names:
                continue
            placeholder = node.meta["val"]
            if not isinstance(placeholder, torch.Tensor):
                continue
            dtype = placeholder.dtype
            if dtype.is_floating_point or dtype.is_complex:
                fake_input = torch.randn(placeholder.shape, dtype=dtype, device="cuda")
            else:
                # torch.randn is not implemented for integer/bool dtypes, so
                # use zeros with the same shape and dtype instead.
                fake_input = torch.zeros(placeholder.shape, dtype=dtype, device="cuda")
            example_inputs.append(fake_input)

        faked_user_inputs = tuple(example_inputs)

        options: dict[str, typing.Any] = {
            # Embed CUDA kernel binaries directly into the compiled shared object
            "aot_inductor.embed_kernel_binary": True,
            # Do not link against the full PyTorch/libtorch library
            "aot_inductor.link_libtorch": False,
            # Package model constants and other generated files directly in the shared object (.so) file
            "aot_inductor.package_constants_in_so": True,
            # Enable maximum automatic tuning for optimal performance
            "max_autotune": True,
            # Use TRITON for GEMM (General Matrix Multiply) operations tuning only to avoid using operators in libtorch
            "max_autotune_gemm_backends": "TRITON",
            # Use TRITON backend for convolution operations tuning only to avoid using operators in libtorch
            "max_autotune_conv_backends": "TRITON",
        }

        # The set is module-level state; drop anything recorded by a previous
        # compilation so it cannot trigger a spurious failure here.
        missing_fallback_kernels.clear()
        with collect_unsupported_fallback_kernels():
            so_path = torch._inductor.aot_compile(edge_program_module, faked_user_inputs, options=options)  # type: ignore[arg-type]

        try:
            if missing_fallback_kernels:
                formatted_kernels = "\n  - ".join(sorted(missing_fallback_kernels))
                raise RuntimeError(
                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n  - {formatted_kernels}\n"
                    "Please add them to the AOTI backend."
                )

            # pyre-ignore[6]: Incompatible parameter type
            with open(so_path, "rb") as f:
                so_data = f.read()
        finally:
            # Clean up the generated so file on every path: on success its
            # contents are packaged into the NamedDataStore below, and on
            # failure the artifact is not used.
            # pyre-ignore[6]: Incompatible parameter type
            os.remove(so_path)

        named_data_store = NamedDataStore()
        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )
0 commit comments