diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py new file mode 100644 index 00000000000..2d396a296bd --- /dev/null +++ b/backends/aoti/aoti_backend.py @@ -0,0 +1,278 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import os +import typing +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional, Set + +import torch +from executorch.backends.aoti.passes.replace_view_copy_with_view import ( + ReplaceViewCopyWithViewPass, +) +from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.exir._warnings import experimental +from executorch.exir.backend.backend_details import ExportedProgram, PreprocessResult +from executorch.exir.backend.compile_spec_schema import CompileSpec +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu +from torch.export.passes import move_to_device_pass + + +class COMPILE_SPEC_KEYS(Enum): + METHOD_NAME = "method_name" + + +@experimental( + "This API and all of aoti-driven backend related functionality are experimental." +) +class AotiBackend(ABC): + """ + Base mixin class for AOTInductor-based backends. + + This class provides common functionality for compiling models using AOTInductor + with different device targets (CUDA, Metal, etc.). + + This is a mixin class, not an actual backend object, for aoti-driven backends. + Concrete backends (e.g., CudaBackend, MetalBackend) should inherit from both + BackendDetails and AotiBackend to get the full functionality. 
+ """ + + @classmethod + @abstractmethod + def get_device_name(cls) -> str: + """Return the device name for this backend (e.g., 'cuda', 'metal').""" + pass + + @classmethod + @abstractmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: + """Return the supported fallback kernels for this backend, keyed by kernel name.""" + pass + + @classmethod + @abstractmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: + """Return the decomposition table for this backend.""" + pass + + @classmethod + @abstractmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """Return the AOTInductor compilation options for this backend.""" + pass + + @classmethod + @abstractmethod + def get_custom_passes(cls) -> List[typing.Any]: + """Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition.""" + pass + + @classmethod + @contextlib.contextmanager + def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]): + """ + Context manager to collect unsupported fallback kernels during compilation. + Monitors both extern kernel calls and runtime lookup. 
+ """ + supported_kernels = cls.get_supported_fallback_kernels() + + original_generate_c_shim_extern_kernel_call = ( + CppWrapperCpu.generate_c_shim_extern_kernel_call + ) + original_generate_fallback_kernel_with_runtime_lookup_aot = ( + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot + ) + + def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( + self, + kernel: str, + args: list[str], + device: str, + *, + debug_args: Optional[list[str]] = None, + debug_handle: Optional[int] = None, + ): + if kernel not in supported_kernels: + missing_fallback_kernels.add(kernel) + + original_generate_c_shim_extern_kernel_call( + self, + kernel, + args, + device, + debug_args=debug_args, + debug_handle=debug_handle, + ) + + def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( + self, + op_overload, + raw_args, + output_args, + raw_outputs, + ): + kernel_name = getattr(op_overload, "_name", str(op_overload)) + if kernel_name not in supported_kernels: + missing_fallback_kernels.add(kernel_name) + + original_generate_fallback_kernel_with_runtime_lookup_aot( + self, op_overload, raw_args, output_args, raw_outputs + ) + + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels + ) + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels + + try: + yield + finally: + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( + original_generate_c_shim_extern_kernel_call + ) + CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( + original_generate_fallback_kernel_with_runtime_lookup_aot + ) + + @classmethod + def preprocess( + cls, + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Preprocess the edge program and compile it using AOTInductor. + Weights are always separated from the SO file. 
+ """ + device_name = cls.get_device_name() + decomposition_table = cls.get_decomposition_table() + options = cls.get_aoti_compile_options(compile_specs) + + # Move the edge_program to the target device + device_edge_program = move_to_device_pass( + edge_program, device_name if device_name != "metal" else "mps" + ) + + # Replace view_copy with view + ReplaceViewCopyWithViewPass()(device_edge_program.graph_module) + + # Apply custom backend-specific passes + custom_passes = cls.get_custom_passes() + for custom_pass in custom_passes: + custom_pass(device_edge_program.graph_module) + + # Run decompositions if any + if decomposition_table: + device_edge_program = device_edge_program.run_decompositions( + decomposition_table + ) + + edge_program_module = device_edge_program.module() + + # Grab all input placeholders from the graph + user_input_names = device_edge_program.graph_signature.user_inputs + user_input_placeholders = [] + for node in device_edge_program.graph.nodes: + if node.op == "placeholder" and node.name in user_input_names: + user_input_placeholders.append(node.meta["val"]) + + # Track missing fallback kernels + missing_fallback_kernels: Set[str] = set() + + # Compile with fallback kernel collection + with cls.collect_unsupported_fallback_kernels( + missing_fallback_kernels + ), torch.no_grad(): + paths = torch._inductor.aot_compile( + edge_program_module, tuple(user_input_placeholders), options=options + ) + + if len(missing_fallback_kernels) > 0: + formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) + method_name = cls.method_name_from_compile_specs(compile_specs) + raise RuntimeError( + f"Method {method_name} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" + "Please add them to the AOTI backend." 
+ ) + + # Extract paths - weights are always separated + so_path = None + blob_path = None + + if isinstance(paths, list): + for path in paths: + if path.endswith(".wrapper.so"): + so_path = path + elif path.endswith(".wrapper_weights.blob"): + blob_path = path + else: + so_path = paths + + if so_path is None or blob_path is None: + raise RuntimeError( + f"Could not find required files in compiled paths, got {paths}" + ) + + # Read SO file + with open(so_path, "rb") as f: + so_data = f.read() + + # Read weights blob + with open(blob_path, "rb") as f: + blob_data = f.read() + + # Create named data store + named_data_store = NamedDataStore() + method_name = cls.method_name_from_compile_specs(compile_specs) + + # Add SO and weights blob separately + named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) + weights_blob_data_type = f"aoti_{device_name}_blob" + named_data_store.add_named_data( + method_name + "_weights_blob", blob_data, 1, weights_blob_data_type + ) + + # Clean up the generated files + os.remove(so_path) + os.remove(blob_path) + + return PreprocessResult( + processed_bytes=b"", + debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), + ) + + @classmethod + def generate_method_name_compile_spec( + cls, + method_name: str, + ) -> CompileSpec: + """ + Generate a CompileSpec for the given method name. + """ + return CompileSpec( + COMPILE_SPEC_KEYS.METHOD_NAME.value, + method_name.encode("utf-8"), + ) + + @classmethod + def method_name_from_compile_specs( + cls, + compile_specs: List[CompileSpec], + ) -> str: + """ + Extract the method name from the compile specs. 
+ """ + for spec in compile_specs: + if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: + return spec.value.decode("utf-8") + raise RuntimeError( + f"Could not find method name in compile specs: {compile_specs}" + ) diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl index be5fe490721..327bef8cc53 100644 --- a/backends/aoti/targets.bzl +++ b/backends/aoti/targets.bzl @@ -16,6 +16,23 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "aoti_backend", + srcs = [ + "aoti_backend.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/aoti/passes:passes", + "//executorch/exir/_serialize:lib", + "//executorch/exir/backend:backend_details", + "//executorch/exir/backend:compile_spec_schema", + ], + ) + # AOTI common shims functionality runtime.cxx_library( name = "common_shims", diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 7d1a5496be3..1b27b027fc2 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -4,107 +4,55 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import contextlib -import os import typing -from enum import Enum +from typing import Any, Dict, final, List -from typing import Any, Dict, final, List, Optional, Set - -import torch -from executorch.backends.aoti.passes.replace_view_copy_with_view import ( - ReplaceViewCopyWithViewPass, -) -from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.backends.aoti.aoti_backend import AotiBackend from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_details import ( - BackendDetails, - ExportedProgram, - PreprocessResult, -) +from executorch.exir.backend.backend_details import BackendDetails from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu -from torch.export.passes import move_to_device_pass - - -# exist fallback operators in et namespace; -supported_fallback_kernels: Dict[str, Any] = { - "aoti_torch_mps_convolution": None, - "aoti_torch_mps_mm_out": None, - "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, -} - -# required fallback kernels but not supported -missing_fallback_kernels: Set[str] = set() - - -class COMPILE_SPEC_KEYS(Enum): - METHOD_NAME = "method_name" - - -# context manager for non-fallback guarantee -# it will raise exception when generating fallback kernels during aoti compile -@contextlib.contextmanager -def collect_unsupported_fallback_kernels(): - original_generate_c_shim_extern_kernel_call = ( - CppWrapperCpu.generate_c_shim_extern_kernel_call - ) - - def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( - self, - kernel: str, - args: list[str], - device: str, - *, - debug_args: Optional[list[str]] = None, - debug_handle: Optional[int] = None, - ): - if kernel not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel) - - original_generate_c_shim_extern_kernel_call( - self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle - ) 
- - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels - ) - try: - yield - finally: - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - original_generate_c_shim_extern_kernel_call - ) @final @experimental( "This API and all of Metal backend related functionality are experimental." ) -class MetalBackend(BackendDetails): - @staticmethod - def preprocess( - edge_program: ExportedProgram, - compile_specs: List[CompileSpec], - ) -> PreprocessResult: - print("entering the lowerable parts in MetalBackend.preprocess....") - # Move the edge_program from CPU to MPS for aoti compile - mps_edge_program = move_to_device_pass(edge_program, "mps") - - # replace slice_copy with slice - ReplaceViewCopyWithViewPass()(mps_edge_program.graph_module) - - edge_program_module = mps_edge_program.module() - - # Grab all input placeholders from the graph - user_input_names = mps_edge_program.graph_signature.user_inputs - user_input_placeholders = [] - for node in mps_edge_program.graph.nodes: - if node.op == "placeholder" and node.name in user_input_names: - user_input_placeholders.append(node.meta["val"]) +class MetalBackend(AotiBackend, BackendDetails): + """ + MetalBackend is a backend that compiles a model to run on Metal/MPS devices. It uses the AOTInductor compiler to generate + optimized Metal kernels for the model's operators without linking against libtorch. The compiled model can be executed on Metal devices + using the ExecuTorch runtime. 
+ """ + + @classmethod + def get_device_name(cls) -> str: + return "metal" + + @classmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: + return { + "aoti_torch_mps_addmm_out": None, + "aoti_torch_mps_convolution": None, + "aoti_torch_mps_mm_out": None, + "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, + } - # Base options for all devices - options: dict[str, typing.Any] = { + @classmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: + return {} + + @classmethod + def get_custom_passes(cls) -> List[typing.Any]: + """Return Metal-specific passes (currently none)""" + return [] + + @classmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """Get AOTI compile options for Metal backend.""" + _ = compile_specs # Unused, but required by interface + return { # Do not link against the full PyTorch/libtorch library "aot_inductor.link_libtorch": False, # Separate weight constants from the .so file @@ -117,83 +65,3 @@ def preprocess( # "aot_inductor.debug_compile": True, # "aot_inductor.force_mmap_weights": False, } - - with collect_unsupported_fallback_kernels(): - paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] - if len(missing_fallback_kernels) > 0: - formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) - raise RuntimeError( - f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" - "Please add them to the AOTI backend." 
- ) - - # Extract the .so and .blob paths from the returned list - so_path = None - blob_path = None - for path in paths: - if path.endswith(".wrapper.so"): - so_path = path - elif path.endswith(".wrapper_weights.blob"): - blob_path = path - - if so_path is None or blob_path is None: - raise RuntimeError( - f"Could not find required files in compiled paths, got {paths}" - ) - - # pyre-ignorep[6]: Incompatible parameter type - with open(so_path, "rb") as f: - so_data = f.read() - - named_data_store = NamedDataStore() - method_name = MetalBackend.method_name_from_compile_specs(compile_specs) - - # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file. - named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) - - # Add weights blob to named data store - with open(blob_path, "rb") as f: - blob_data = f.read() - - named_data_store.add_named_data( - method_name + "_weights_blob", blob_data, 1, "aoti_metal_blob" - ) - - # Clean up the weights blob file - os.remove(blob_path) - - # Clean up the generated so file; it has been packaged into the NamedDataStore - # pyre-ignorep[6]: Incompatible parameter type - os.remove(so_path) - - return PreprocessResult( - processed_bytes=b"", - debug_handle_map={}, - data_store_output=named_data_store.get_named_data_store_output(), - ) - - @staticmethod - def generate_method_name_compile_spec( - method_name: str, - ) -> CompileSpec: - """ - Generates a CompileSpec for the given method name. - """ - return CompileSpec( - COMPILE_SPEC_KEYS.METHOD_NAME.value, - method_name.encode("utf-8"), - ) - - @staticmethod - def method_name_from_compile_specs( - compile_specs: List[CompileSpec], - ) -> str: - """ - Returns the method name from the compile specs. 
- """ - for spec in compile_specs: - if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: - return spec.value.decode("utf-8") - raise RuntimeError( - f"Could not find method name in compile specs: {compile_specs}" - ) diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS index d8256f77c41..3ae4eec6680 100644 --- a/backends/cuda/TARGETS +++ b/backends/cuda/TARGETS @@ -17,6 +17,7 @@ runtime.python_library( "//executorch/exir/_serialize:lib", "//executorch/exir/backend:backend_details", "//executorch/exir/backend:compile_spec_schema", + "//executorch/backends/aoti:aoti_backend", ], ) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index 772e24c75b3..cc2d662b335 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -4,150 +4,63 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib -import os import typing -from enum import Enum from importlib import resources - -from typing import Any, Dict, final, List, Optional, Set +from typing import Any, Dict, final, List import torch -from executorch.backends.aoti.passes.replace_view_copy_with_view import ( - ReplaceViewCopyWithViewPass, -) - +from executorch.backends.aoti.aoti_backend import AotiBackend from executorch.backends.cuda.triton.replacement_pass import ( ReplaceEdgeOpWithTritonOpPass, ) -from executorch.exir._serialize._named_data_store import NamedDataStore from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_details import ( - BackendDetails, - ExportedProgram, - PreprocessResult, -) +from executorch.exir.backend.backend_details import BackendDetails from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.decomposition import conv1d_to_conv2d -from torch.export.passes import move_to_device_pass - - -cuda_decomposition_table = { 
- torch.ops.aten.conv1d.default: conv1d_to_conv2d, -} - -# exist fallback operators in et namespace; -supported_fallback_kernels: Dict[str, Any] = { - "at::_ops::_weight_int4pack_mm::call": None, -} - -# required fallback kernels but not supported -missing_fallback_kernels: Set[str] = set() - - -class COMPILE_SPEC_KEYS(Enum): - METHOD_NAME = "method_name" - - -# context manager for non-fallback guarantee -# it will raise exception when generating fallback kernels during aoti compile -@contextlib.contextmanager -def collect_unsupported_fallback_kernels(): - original_generate_c_shim_extern_kernel_call = ( - CppWrapperCpu.generate_c_shim_extern_kernel_call - ) - original_generate_fallback_kernel_with_runtime_lookup_aot = ( - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot - ) - - def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( - self, - kernel: str, - args: list[str], - device: str, - *, - debug_args: Optional[list[str]] = None, - ): - if kernel not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel) - - original_generate_c_shim_extern_kernel_call( - self, kernel, args, device, debug_args=debug_args - ) - - def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( - self, - op_overload, - raw_args, - output_args, - raw_outputs, - ): - # Extract kernel name for collection - kernel_name = getattr(op_overload, "_name", str(op_overload)) - if kernel_name not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel_name) - - original_generate_fallback_kernel_with_runtime_lookup_aot( - self, op_overload, raw_args, output_args, raw_outputs - ) - - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels - ) - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( - generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels - ) - try: - yield - finally: - 
CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - original_generate_c_shim_extern_kernel_call - ) - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( - original_generate_fallback_kernel_with_runtime_lookup_aot - ) @final @experimental( "This API and all of cuda backend related functionality are experimental." ) -class CudaBackend(BackendDetails): +class CudaBackend(AotiBackend, BackendDetails): """ CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate optimized CUDA kernels for the model's operators with libtorch-free. The compiled model can be executed on CUDA devices using the Executorch runtime. """ - @staticmethod - def preprocess( # noqa: C901 - edge_program: ExportedProgram, - compile_specs: List[CompileSpec], - ) -> PreprocessResult: - # Move the edge_program from CPU to CUDA for aoti compile - cuda_edge_program = move_to_device_pass(edge_program, "cuda") - - # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int - ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module) - - # Replace aten ops with triton ops - ReplaceEdgeOpWithTritonOpPass()(cuda_edge_program.graph_module) + @classmethod + def get_device_name(cls) -> str: + return "cuda" - cuda_edge_program = cuda_edge_program.run_decompositions( - cuda_decomposition_table - ) + @classmethod + def get_supported_fallback_kernels(cls) -> Dict[str, Any]: + return { + "at::_ops::_weight_int4pack_mm::call": None, + } - edge_program_module = cuda_edge_program.module() + @classmethod + def get_decomposition_table(cls) -> Dict[Any, Any]: + return { + torch.ops.aten.conv1d.default: conv1d_to_conv2d, + } - # Grab all input placeholders from the graph - user_input_names = cuda_edge_program.graph_signature.user_inputs - user_input_placeholders = [] - for node in cuda_edge_program.graph.nodes: - if node.op == "placeholder" and node.name in user_input_names: - user_input_placeholders.append(node.meta["val"]) + 
@classmethod + def get_custom_passes(cls) -> List[typing.Any]: + """Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass""" + return [ReplaceEdgeOpWithTritonOpPass()] - options: dict[str, typing.Any] = { + @classmethod + def get_aoti_compile_options( + cls, compile_specs: List[CompileSpec] + ) -> Dict[str, typing.Any]: + """ + Get AOTI compile options for CUDA backend. + Options may vary based on platform (Linux vs Windows). + """ + # Base options for all platforms + options: Dict[str, typing.Any] = { # Disable this to support sdpa decomposition # TODO(gasoonjia): remove it after pin bump to latest pytorch "loop_ordering_after_fusion": False, @@ -170,6 +83,7 @@ def preprocess( # noqa: C901 "max_autotune_conv_backends": "TRITON", } + # Parse compile_specs to check for platform platform = "linux" shim_library_path = None for spec in compile_specs: @@ -178,14 +92,13 @@ def preprocess( # noqa: C901 if spec.key == "shim_library_path": shim_library_path = spec.value.decode("utf-8") - assert platform == "linux" or platform == "windows" - if platform == "windows" and shim_library_path is None: - lib_dir = resources.files("executorch").joinpath("data/lib") - shim_library_path = str(lib_dir) - if platform == "linux": - assert shim_library_path is None - + # Add platform-specific options if platform == "windows": + # For Windows, get default shim library path if not provided + if shim_library_path is None: + lib_dir = resources.files("executorch").joinpath("data/lib") + shim_library_path = str(lib_dir) + options.update( { "aot_inductor.cross_target_platform": "windows", @@ -194,84 +107,10 @@ def preprocess( # noqa: C901 "aot_inductor.precompile_headers": False, } ) + else: + # Linux platform + assert ( + shim_library_path is None + ), "shim_library_path should not be set for Linux" - with collect_unsupported_fallback_kernels(), torch.no_grad(): - # torch._logging.set_logs(post_grad_graphs=True) - # Here we should expect 1 so file and 1 weight blob in the same directory. 
- paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] - if len(missing_fallback_kernels) > 0: - formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) - raise RuntimeError( - f"Method {CudaBackend.method_name_from_compile_specs(compile_specs)} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" - "Please add them to the AOTI backend." - ) - - # Extract the .so and .blob paths from the returned list - so_path = None - blob_path = None - for path in paths: - if path.endswith(".wrapper.so"): - so_path = path - elif path.endswith(".wrapper_weights.blob"): - blob_path = path - - if so_path is None or blob_path is None: - raise RuntimeError( - f"Could not find required files in compiled paths, got {paths}" - ) - - # pyre-ignorep[6]: Incompatible parameter type - with open(so_path, "rb") as f: - so_data = f.read() - - named_data_store = NamedDataStore() - method_name = CudaBackend.method_name_from_compile_specs(compile_specs) - - # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file. 
- named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) - - # Add weights blob to named data store - with open(blob_path, "rb") as f: - blob_data = f.read() - named_data_store.add_named_data( - method_name + "_weights_blob", blob_data, 1, "aoti_cuda_blob" - ) - # Clean up the weights blob file - os.remove(blob_path) - - # Clean up the generated so file; it has been packaged into the NamedDataStore - # pyre-ignorep[6]: Incompatible parameter type - os.remove(so_path) - - return PreprocessResult( - processed_bytes=b"", - debug_handle_map={}, - data_store_output=named_data_store.get_named_data_store_output(), - ) - - @staticmethod - def generate_method_name_compile_spec( - method_name: str, - ) -> CompileSpec: - """ - Returns the compile spec representing the model compute precision, for additional details - please refer to the documentation for ``coremltools.precision``. - """ - return CompileSpec( - COMPILE_SPEC_KEYS.METHOD_NAME.value, - method_name.encode("utf-8"), - ) - - @staticmethod - def method_name_from_compile_specs( - compile_specs: List[CompileSpec], - ) -> str: - """ - Returns the method name from the compile specs. - """ - for spec in compile_specs: - if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: - return spec.value.decode("utf-8") - raise RuntimeError( - f"Could not find method name in compile specs: {compile_specs}" - ) + return options