# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import tempfile
import typing

from typing import Any, Dict, final, List, Optional, Set

import torch
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass


# Fallback operators that already exist (and are supported) in the et namespace
supported_fallback_kernels: Dict[str, Any] = {}

# Fallback kernels that the model requires but that are not supported
missing_fallback_kernels: Set[str] = set()


# Context manager that guards against unsupported fallbacks: it records every
# fallback kernel generated during AOTInductor (AOTI) compilation so the caller
# can raise if any of them are unsupported
@contextlib.contextmanager
def collect_unsupported_fallback_kernels():
    original_generate_c_shim_extern_kernel_call = (
        CppWrapperCpu.generate_c_shim_extern_kernel_call
    )

    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
        self,
        kernel: str,
        args: list[str],
        device: str,
        *,
        debug_args: Optional[list[str]] = None,
    ):
        if kernel not in supported_fallback_kernels:
            missing_fallback_kernels.add(kernel)

        original_generate_c_shim_extern_kernel_call(
            self, kernel, args, device, debug_args=debug_args
        )

    CppWrapperCpu.generate_c_shim_extern_kernel_call = (
        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
    )
    try:
        yield
    finally:
        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
            original_generate_c_shim_extern_kernel_call
        )


@final
class CudaBackend(BackendDetails):
    """
    CudaBackend compiles a model to run on CUDA devices. It uses the AOTInductor
    compiler to generate optimized, libtorch-free CUDA kernels for the model's
    operators. The compiled model can then be executed on CUDA devices by the
    ExecuTorch runtime.
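
    A minimal usage sketch, assuming this backend is registered under the
    backend id "CudaBackend" and that ``ep`` is an exported edge program
    (exact entry points may differ):

        from executorch.exir.backend.backend_api import to_backend

        lowered_module = to_backend("CudaBackend", ep, [])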
    """

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        # Step 1: Move the edge_program from CPU to CUDA for AOTInductor compilation
        cuda_edge_program = move_to_device_pass(edge_program, "cuda")

        edge_program_module = cuda_edge_program.module()

        # Step 2: Collect the FakeTensor values of all user-input placeholders from the graph
        user_input_names = cuda_edge_program.graph_signature.user_inputs
        user_input_placeholders = []
        for node in cuda_edge_program.graph.nodes:
            if node.op == "placeholder" and node.name in user_input_names:
                user_input_placeholders.append(node.meta["val"])

        # Step 3: Create pseudo user inputs with torch.randn, matching each placeholder's shape and dtype
        faked_user_inputs = []
        for placeholder in user_input_placeholders:
            if isinstance(placeholder, torch.Tensor):
                # Generate fake input with same shape and dtype, on CUDA
                fake_input = torch.randn(
                    placeholder.shape, dtype=placeholder.dtype, device="cuda"
                )
                faked_user_inputs.append(fake_input)

        faked_user_inputs = tuple(faked_user_inputs)

        # Create a temporary file path for the compiled shared library output
        output_path = tempfile.mktemp(suffix=".so", prefix="aoti_")

        options: dict[str, typing.Any] = {
            # Embed CUDA kernel binaries directly into the compiled shared object
            "aot_inductor.embed_kernel_binary": True,
            # Do not link against the full PyTorch/libtorch library
            "aot_inductor.link_libtorch": False,
            # Package model constants and other generated files directly in the shared object (.so) file
            "aot_inductor.package_constants_in_so": True,
            # Specify the output file path for the compiled shared object
            "aot_inductor.output_path": output_path,
            # Enable maximum automatic tuning for optimal performance
            "max_autotune": True,
            # Restrict GEMM (General Matrix Multiply) autotuning to the TRITON backend to avoid pulling in libtorch operators
            "max_autotune_gemm_backends": "TRITON",
            # Restrict convolution autotuning to the TRITON backend for the same reason
            "max_autotune_conv_backends": "TRITON",
        }

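        # Compile with AOTInductor while recording any fallback kernels the
        # generated code would call; fail below if any of them are unsupported.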
        with collect_unsupported_fallback_kernels():
            _ = torch._inductor.aot_compile(edge_program_module, faked_user_inputs, options=options)  # type: ignore[arg-type]
            if len(missing_fallback_kernels) > 0:
                formatted_kernels = "\n  - ".join(sorted(missing_fallback_kernels))
                raise RuntimeError(
                    f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n  - {formatted_kernels}\n"
                    "Please add them to the AOTI backend."
                )

        with open(output_path, "rb") as f:
            so_data = f.read()

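        # Hand the compiled shared object to the runtime via the named data
        # store, keyed as "so_blob" and tagged "aoti_cuda_blob".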
        named_data_store = NamedDataStore()
        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")

        # Clean up the temporary output file
        os.remove(output_path)

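        # The payload lives entirely in the named data store, so no processed
        # bytes are emitted for this delegate.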
        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )