diff --git a/backends/aoti/aoti_backend.py b/backends/aoti/aoti_backend.py
new file mode 100644
index 00000000000..6c1a8a8661c
--- /dev/null
+++ b/backends/aoti/aoti_backend.py
@@ -0,0 +1,261 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import contextlib
+import os
+import typing
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import Any, Dict, List, Optional, Set
+
+import torch
+from executorch.backends.aoti.passes.replace_view_copy_with_view import (
+    ReplaceViewCopyWithViewPass,
+)
+from executorch.exir._serialize._named_data_store import NamedDataStore
+from executorch.exir._warnings import experimental
+from executorch.exir.backend.backend_details import (
+    BackendDetails,
+    ExportedProgram,
+    PreprocessResult,
+)
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
+from torch.export.passes import move_to_device_pass
+
+
+class COMPILE_SPEC_KEYS(Enum):
+    METHOD_NAME = "method_name"
+
+
+@experimental(
+    "This API and all AOTI-driven backend functionality are experimental."
+)
+class AotiBackend(BackendDetails, ABC):
+    """
+    Abstract base class for AOTInductor-based backends.
+
+    This class provides common functionality for compiling models using AOTInductor
+    with different device targets (CUDA, Metal/MPS, etc.).
+    """
+
+    @staticmethod
+    @abstractmethod
+    def get_device_name() -> str:
+        """Return the device name for this backend (e.g., 'cuda', 'mps')."""
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def get_supported_fallback_kernels() -> Dict[str, Any]:
+        """Return the mapping of supported fallback kernels for this backend."""
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def get_decomposition_table() -> Dict[Any, Any]:
+        """Return the decomposition table for this backend."""
+        pass
+
+    @staticmethod
+    @abstractmethod
+    def get_aoti_compile_options() -> Dict[str, typing.Any]:
+        """Return the AOTInductor compilation options for this backend."""
+        pass
+
+    @classmethod
+    @contextlib.contextmanager
+    def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
+        """
+        Context manager that collects unsupported fallback kernels during compilation.
+        Monitors both extern kernel calls and runtime-lookup fallbacks.
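+
+        Illustrative sketch of the intended use (mirrors how ``preprocess`` below
+        uses it; ``gm`` and ``example_inputs`` are placeholder names):
+
+            missing: Set[str] = set()
+            with cls.collect_unsupported_fallback_kernels(missing):
+                torch._inductor.aot_compile(gm, example_inputs, options=options)
+            # `missing` now holds every fallback kernel this backend cannot service.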
+        """
+        supported_kernels = cls.get_supported_fallback_kernels()
+
+        original_generate_c_shim_extern_kernel_call = (
+            CppWrapperCpu.generate_c_shim_extern_kernel_call
+        )
+        original_generate_fallback_kernel_with_runtime_lookup_aot = (
+            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot
+        )
+
+        def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
+            self,
+            kernel: str,
+            args: list[str],
+            device: str,
+            *,
+            debug_args: Optional[list[str]] = None,
+            debug_handle: Optional[int] = None,
+        ):
+            if kernel not in supported_kernels:
+                missing_fallback_kernels.add(kernel)
+
+            original_generate_c_shim_extern_kernel_call(
+                self,
+                kernel,
+                args,
+                device,
+                debug_args=debug_args,
+                debug_handle=debug_handle,
+            )
+
+        def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
+            self,
+            op_overload,
+            raw_args,
+            output_args,
+            raw_outputs,
+        ):
+            kernel_name = getattr(op_overload, "_name", str(op_overload))
+            if kernel_name not in supported_kernels:
+                missing_fallback_kernels.add(kernel_name)
+
+            original_generate_fallback_kernel_with_runtime_lookup_aot(
+                self, op_overload, raw_args, output_args, raw_outputs
+            )
+
+        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
+            generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
+        )
+        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
+            generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels
+        )
+
+        try:
+            yield
+        finally:
+            CppWrapperCpu.generate_c_shim_extern_kernel_call = (
+                original_generate_c_shim_extern_kernel_call
+            )
+            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
+                original_generate_fallback_kernel_with_runtime_lookup_aot
+            )
+
+    @classmethod
+    def preprocess(
+        cls,
+        edge_program: ExportedProgram,
+        compile_specs: List[CompileSpec],
+    ) -> PreprocessResult:
+        """
+        Preprocess the edge program and compile it using AOTInductor.
+        Weights are always separated from the SO file.
+        """
+        device_name = cls.get_device_name()
+        decomposition_table = cls.get_decomposition_table()
+        options = cls.get_aoti_compile_options()
+
+        # Move the edge_program to the target device
+        device_edge_program = move_to_device_pass(edge_program, device_name)
+
+        # Replace view_copy with view
+        ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)
+
+        # Run decompositions if any
+        if decomposition_table:
+            device_edge_program = device_edge_program.run_decompositions(
+                decomposition_table
+            )
+
+        edge_program_module = device_edge_program.module()
+
+        # Grab all input placeholders from the graph
+        user_input_names = device_edge_program.graph_signature.user_inputs
+        user_input_placeholders = []
+        for node in device_edge_program.graph.nodes:
+            if node.op == "placeholder" and node.name in user_input_names:
+                user_input_placeholders.append(node.meta["val"])
+
+        # Track missing fallback kernels
+        missing_fallback_kernels: Set[str] = set()
+
+        # Compile with fallback kernel collection
+        with cls.collect_unsupported_fallback_kernels(
+            missing_fallback_kernels
+        ), torch.no_grad():
+            paths = torch._inductor.aot_compile(
+                edge_program_module, tuple(user_input_placeholders), options=options
+            )
+
+        if len(missing_fallback_kernels) > 0:
+            formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
+            method_name = cls.method_name_from_compile_specs(compile_specs)
+            raise RuntimeError(
+                f"Method {method_name} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
+                "Please add them to the AOTI backend."
+ ) + + # Extract paths - weights are always separated + so_path = None + blob_path = None + + if isinstance(paths, list): + for path in paths: + if path.endswith(".wrapper.so"): + so_path = path + elif path.endswith(".wrapper_weights.blob"): + blob_path = path + else: + so_path = paths + + if so_path is None or blob_path is None: + raise RuntimeError( + f"Could not find required files in compiled paths, got {paths}" + ) + + # Read SO file + with open(so_path, "rb") as f: + so_data = f.read() + + # Read weights blob + with open(blob_path, "rb") as f: + blob_data = f.read() + + # Create named data store + named_data_store = NamedDataStore() + method_name = cls.method_name_from_compile_specs(compile_specs) + + # Add SO and weights blob separately + named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) + weights_blob_data_type = f"aoti_{device_name}_blob" + named_data_store.add_named_data( + method_name + "_weights_blob", blob_data, 1, weights_blob_data_type + ) + + # Clean up the generated files + os.remove(so_path) + os.remove(blob_path) + + return PreprocessResult( + processed_bytes=b"", + debug_handle_map={}, + data_store_output=named_data_store.get_named_data_store_output(), + ) + + @staticmethod + def generate_method_name_compile_spec( + method_name: str, + ) -> CompileSpec: + """ + Generate a CompileSpec for the given method name. + """ + return CompileSpec( + COMPILE_SPEC_KEYS.METHOD_NAME.value, + method_name.encode("utf-8"), + ) + + @staticmethod + def method_name_from_compile_specs( + compile_specs: List[CompileSpec], + ) -> str: + """ + Extract the method name from the compile specs. + """ + for spec in compile_specs: + if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: + return spec.value.decode("utf-8") + raise RuntimeError( + f"Could not find method name in compile specs: {compile_specs}" + ) diff --git a/backends/aoti/targets.bzl b/backends/aoti/targets.bzl index 560cf52e06f..193b82fdd16 100644 --- a/backends/aoti/targets.bzl +++ b/backends/aoti/targets.bzl @@ -16,6 +16,23 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "aoti_backend", + srcs = [ + "aoti_backend.py", + ], + visibility = [ + "//executorch/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/backends/aoti/passes:passes", + "//executorch/exir/_serialize:lib", + "//executorch/exir/backend:backend_details", + "//executorch/exir/backend:compile_spec_schema", + ], + ) + # AOTI common shims functionality runtime.cxx_library( name = "common_shims", diff --git a/backends/apple/metal/metal_backend.py b/backends/apple/metal/metal_backend.py index 7d1a5496be3..d73639beb54 100644 --- a/backends/apple/metal/metal_backend.py +++ b/backends/apple/metal/metal_backend.py @@ -4,107 +4,44 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import contextlib
-import os
 import typing
-from enum import Enum
+from typing import Any, Dict, final
 
-from typing import Any, Dict, final, List, Optional, Set
-
-import torch
-from executorch.backends.aoti.passes.replace_view_copy_with_view import (
-    ReplaceViewCopyWithViewPass,
-)
-from executorch.exir._serialize._named_data_store import NamedDataStore
+from executorch.backends.aoti.aoti_backend import AotiBackend
 from executorch.exir._warnings import experimental
-from executorch.exir.backend.backend_details import (
-    BackendDetails,
-    ExportedProgram,
-    PreprocessResult,
-)
-from executorch.exir.backend.compile_spec_schema import CompileSpec
-from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
-from torch.export.passes import move_to_device_pass
-
-
-# exist fallback operators in et namespace;
-supported_fallback_kernels: Dict[str, Any] = {
-    "aoti_torch_mps_convolution": None,
-    "aoti_torch_mps_mm_out": None,
-    "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
-}
-
-# required fallback kernels but not supported
-missing_fallback_kernels: Set[str] = set()
-
-
-class COMPILE_SPEC_KEYS(Enum):
-    METHOD_NAME = "method_name"
-
-
-# context manager for non-fallback guarantee
-# it will raise exception when generating fallback kernels during aoti compile
-@contextlib.contextmanager
-def collect_unsupported_fallback_kernels():
-    original_generate_c_shim_extern_kernel_call = (
-        CppWrapperCpu.generate_c_shim_extern_kernel_call
-    )
-
-    def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
-        self,
-        kernel: str,
-        args: list[str],
-        device: str,
-        *,
-        debug_args: Optional[list[str]] = None,
-        debug_handle: Optional[int] = None,
-    ):
-        if kernel not in supported_fallback_kernels:
-            missing_fallback_kernels.add(kernel)
-
-        original_generate_c_shim_extern_kernel_call(
-            self, kernel, args, device, debug_args=debug_args, debug_handle=debug_handle
-        )
-
-    CppWrapperCpu.generate_c_shim_extern_kernel_call = (
-        generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
-    )
-    try:
-        yield
-    finally:
-        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
-            original_generate_c_shim_extern_kernel_call
-        )
 
 
 @final
 @experimental(
     "This API and all of Metal backend related functionality are experimental."
 )
-class MetalBackend(BackendDetails):
-    @staticmethod
-    def preprocess(
-        edge_program: ExportedProgram,
-        compile_specs: List[CompileSpec],
-    ) -> PreprocessResult:
-        print("entering the lowerable parts in MetalBackend.preprocess....")
-        # Move the edge_program from CPU to MPS for aoti compile
-        mps_edge_program = move_to_device_pass(edge_program, "mps")
+class MetalBackend(AotiBackend):
+    """
+    MetalBackend is a backend that compiles a model to run on Metal/MPS devices. It uses the AOTInductor
+    compiler to generate optimized Metal kernels for the model's operators without depending on libtorch.
+    The compiled model can be executed on Metal devices using the ExecuTorch runtime.
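+
+    Hypothetical usage sketch (assumes the standard ``to_backend`` lowering entry
+    point; names are illustrative):
+
+        compile_specs = [MetalBackend.generate_method_name_compile_spec("forward")]
+        lowered = to_backend("MetalBackend", edge_program, compile_specs)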
+ """ - # replace slice_copy with slice - ReplaceViewCopyWithViewPass()(mps_edge_program.graph_module) + @staticmethod + def get_device_name() -> str: + return "mps" - edge_program_module = mps_edge_program.module() + @staticmethod + def get_supported_fallback_kernels() -> Dict[str, Any]: + return { + "aoti_torch_mps_addmm_out": None, + "aoti_torch_mps_convolution": None, + "aoti_torch_mps_mm_out": None, + "at::_ops::_scaled_dot_product_attention_math_for_mps::call": None, + } - # Grab all input placeholders from the graph - user_input_names = mps_edge_program.graph_signature.user_inputs - user_input_placeholders = [] - for node in mps_edge_program.graph.nodes: - if node.op == "placeholder" and node.name in user_input_names: - user_input_placeholders.append(node.meta["val"]) + @staticmethod + def get_decomposition_table() -> Dict[Any, Any]: + return {} - # Base options for all devices - options: dict[str, typing.Any] = { + @staticmethod + def get_aoti_compile_options() -> Dict[str, typing.Any]: + return { # Do not link against the full PyTorch/libtorch library "aot_inductor.link_libtorch": False, # Separate weight constants from the .so file @@ -117,83 +54,3 @@ def preprocess( # "aot_inductor.debug_compile": True, # "aot_inductor.force_mmap_weights": False, } - - with collect_unsupported_fallback_kernels(): - paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] - if len(missing_fallback_kernels) > 0: - formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) - raise RuntimeError( - f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" - "Please add them to the AOTI backend." - ) - - # Extract the .so and .blob paths from the returned list - so_path = None - blob_path = None - for path in paths: - if path.endswith(".wrapper.so"): - so_path = path - elif path.endswith(".wrapper_weights.blob"): - blob_path = path - - if so_path is None or blob_path is None: - raise RuntimeError( - f"Could not find required files in compiled paths, got {paths}" - ) - - # pyre-ignorep[6]: Incompatible parameter type - with open(so_path, "rb") as f: - so_data = f.read() - - named_data_store = NamedDataStore() - method_name = MetalBackend.method_name_from_compile_specs(compile_specs) - - # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file. - named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) - - # Add weights blob to named data store - with open(blob_path, "rb") as f: - blob_data = f.read() - - named_data_store.add_named_data( - method_name + "_weights_blob", blob_data, 1, "aoti_metal_blob" - ) - - # Clean up the weights blob file - os.remove(blob_path) - - # Clean up the generated so file; it has been packaged into the NamedDataStore - # pyre-ignorep[6]: Incompatible parameter type - os.remove(so_path) - - return PreprocessResult( - processed_bytes=b"", - debug_handle_map={}, - data_store_output=named_data_store.get_named_data_store_output(), - ) - - @staticmethod - def generate_method_name_compile_spec( - method_name: str, - ) -> CompileSpec: - """ - Generates a CompileSpec for the given method name. - """ - return CompileSpec( - COMPILE_SPEC_KEYS.METHOD_NAME.value, - method_name.encode("utf-8"), - ) - - @staticmethod - def method_name_from_compile_specs( - compile_specs: List[CompileSpec], - ) -> str: - """ - Returns the method name from the compile specs. 
- """ - for spec in compile_specs: - if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: - return spec.value.decode("utf-8") - raise RuntimeError( - f"Could not find method name in compile specs: {compile_specs}" - ) diff --git a/backends/cuda/TARGETS b/backends/cuda/TARGETS index 94af87bbaed..6ff8b8625fd 100644 --- a/backends/cuda/TARGETS +++ b/backends/cuda/TARGETS @@ -11,11 +11,7 @@ runtime.python_library( "//executorch/...", ], deps = [ - "//caffe2:torch", - "//executorch/backends/aoti/passes:passes", - "//executorch/exir/_serialize:lib", - "//executorch/exir/backend:backend_details", - "//executorch/exir/backend:compile_spec_schema", + "//executorch/backends/aoti:aoti_backend", ], ) diff --git a/backends/cuda/cuda_backend.py b/backends/cuda/cuda_backend.py index f8482835ea5..0ba45e44060 100644 --- a/backends/cuda/cuda_backend.py +++ b/backends/cuda/cuda_backend.py @@ -4,111 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib -import os import typing -from enum import Enum - -from typing import Any, Dict, final, List, Optional, Set +from typing import Any, Dict, final import torch -from executorch.backends.aoti.passes.replace_view_copy_with_view import ( - ReplaceViewCopyWithViewPass, -) -from executorch.exir._serialize._named_data_store import NamedDataStore +from executorch.backends.aoti.aoti_backend import AotiBackend from executorch.exir._warnings import experimental -from executorch.exir.backend.backend_details import ( - BackendDetails, - ExportedProgram, - PreprocessResult, -) -from executorch.exir.backend.compile_spec_schema import CompileSpec -from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.decomposition import conv1d_to_conv2d -from torch.export.passes import move_to_device_pass -from torch.nn.attention import SDPBackend - -cuda_decomposition_table = { - torch.ops.aten.conv1d.default: conv1d_to_conv2d, -} - -# exist fallback operators in et namespace; -supported_fallback_kernels: Dict[str, Any] = { - "at::_ops::_weight_int4pack_mm::call": None, -} - -# required fallback kernels but not supported -missing_fallback_kernels: Set[str] = set() - - -class COMPILE_SPEC_KEYS(Enum): - METHOD_NAME = "method_name" - - -# context manager for non-fallback guarantee -# it will raise exception when generating fallback kernels during aoti compile -@contextlib.contextmanager -def collect_unsupported_fallback_kernels(): - original_generate_c_shim_extern_kernel_call = ( - CppWrapperCpu.generate_c_shim_extern_kernel_call - ) - original_generate_fallback_kernel_with_runtime_lookup_aot = ( - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot - ) - - def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( - self, - kernel: str, - args: list[str], - device: str, - *, - debug_args: Optional[list[str]] = None, - ): - if kernel not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel) - - original_generate_c_shim_extern_kernel_call( - self, kernel, args, device, debug_args=debug_args - ) - - def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( - self, - op_overload, - raw_args, - output_args, - raw_outputs, - ): - # Extract kernel name for collection - kernel_name = getattr(op_overload, "_name", str(op_overload)) - if kernel_name not in supported_fallback_kernels: - missing_fallback_kernels.add(kernel_name) - - original_generate_fallback_kernel_with_runtime_lookup_aot( - self, 
op_overload, raw_args, output_args, raw_outputs - ) - - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels - ) - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( - generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels - ) - try: - yield - finally: - CppWrapperCpu.generate_c_shim_extern_kernel_call = ( - original_generate_c_shim_extern_kernel_call - ) - CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = ( - original_generate_fallback_kernel_with_runtime_lookup_aot - ) @final @experimental( "This API and all of cuda backend related functionality are experimental." ) -class CudaBackend(BackendDetails): +class CudaBackend(AotiBackend): """ CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate optimized CUDA kernels for the model's operators with libtorch-free. The compiled model can be executed on CUDA devices @@ -116,30 +25,24 @@ class CudaBackend(BackendDetails): """ @staticmethod - def preprocess( - edge_program: ExportedProgram, - compile_specs: List[CompileSpec], - ) -> PreprocessResult: - # Move the edge_program from CPU to CUDA for aoti compile - cuda_edge_program = move_to_device_pass(edge_program, "cuda") - - # replace slice_copy.Tensor with slice.Tensor, select_copy.int with select.int - ReplaceViewCopyWithViewPass()(cuda_edge_program.graph_module) + def get_device_name() -> str: + return "cuda" - cuda_edge_program = cuda_edge_program.run_decompositions( - cuda_decomposition_table - ) - - edge_program_module = cuda_edge_program.module() + @staticmethod + def get_supported_fallback_kernels() -> Dict[str, Any]: + return { + "at::_ops::_weight_int4pack_mm::call": None, + } - # Grab all input placeholders from the graph - user_input_names = cuda_edge_program.graph_signature.user_inputs - user_input_placeholders = [] - for node in cuda_edge_program.graph.nodes: - if node.op == "placeholder" and node.name in user_input_names: - user_input_placeholders.append(node.meta["val"]) + @staticmethod + def get_decomposition_table() -> Dict[Any, Any]: + return { + torch.ops.aten.conv1d.default: conv1d_to_conv2d, + } - options: dict[str, typing.Any] = { + @staticmethod + def get_aoti_compile_options() -> Dict[str, typing.Any]: + return { # Disable this to support sdpa decomposition # TODO(gasoonjia): remove it after pin bump to latest pytorch "loop_ordering_after_fusion": False, @@ -161,88 +64,3 @@ def preprocess( # Use TRITON backend for convolution operations tuning only to avoid using operators in libtorch "max_autotune_conv_backends": "TRITON", } - - with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel( - [ - SDPBackend.MATH # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`. - ] - ), torch.no_grad(): - # torch._logging.set_logs(post_grad_graphs=True) - # Here we should expect 1 so file and 1 weight blob in the same directory. - paths = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type] - if len(missing_fallback_kernels) > 0: - formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) - raise RuntimeError( - f"Method {CudaBackend.method_name_from_compile_specs(compile_specs)} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" - "Please add them to the AOTI backend." 
- ) - - # Extract the .so and .blob paths from the returned list - so_path = None - blob_path = None - for path in paths: - if path.endswith(".wrapper.so"): - so_path = path - elif path.endswith(".wrapper_weights.blob"): - blob_path = path - - if so_path is None or blob_path is None: - raise RuntimeError( - f"Could not find required files in compiled paths, got {paths}" - ) - - # pyre-ignorep[6]: Incompatible parameter type - with open(so_path, "rb") as f: - so_data = f.read() - - named_data_store = NamedDataStore() - method_name = CudaBackend.method_name_from_compile_specs(compile_specs) - - # Keep the so file in the NamedDataStore, so that it can be packaged into the .pte file. - named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None) - - # Add weights blob to named data store - with open(blob_path, "rb") as f: - blob_data = f.read() - named_data_store.add_named_data( - method_name + "_weights_blob", blob_data, 1, "aoti_cuda_blob" - ) - # Clean up the weights blob file - os.remove(blob_path) - - # Clean up the generated so file; it has been packaged into the NamedDataStore - # pyre-ignorep[6]: Incompatible parameter type - os.remove(so_path) - - return PreprocessResult( - processed_bytes=b"", - debug_handle_map={}, - data_store_output=named_data_store.get_named_data_store_output(), - ) - - @staticmethod - def generate_method_name_compile_spec( - method_name: str, - ) -> CompileSpec: - """ - Returns the compile spec representing the model compute precision, for additional details - please refer to the documentation for ``coremltools.precision``. - """ - return CompileSpec( - COMPILE_SPEC_KEYS.METHOD_NAME.value, - method_name.encode("utf-8"), - ) - - @staticmethod - def method_name_from_compile_specs( - compile_specs: List[CompileSpec], - ) -> str: - """ - Returns the method name from the compile specs. - """ - for spec in compile_specs: - if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value: - return spec.value.decode("utf-8") - raise RuntimeError( - f"Could not find method name in compile specs: {compile_specs}" - ) diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index dd8d97d66ac..3a072e03599 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -10,7 +10,7 @@ from contextlib import contextmanager, nullcontext from dataclasses import dataclass from functools import singledispatch -from typing import Dict, Generator, List, Mapping +from typing import Dict, Generator, List, Mapping, Set import torch @@ -581,9 +581,21 @@ def lower_all_submodules_to_backend( for method_name, call_submodule_nodes in method_to_submodules_nodes.items() } + def _get_all_final_backend_details_subclasses(cls) -> Set[type]: + subclasses = set() + if len(cls.__subclasses__()) == 0: + return {cls} + else: + for subclass in cls.__subclasses__(): + # Recursively check subclasses + subclasses.update(_get_all_final_backend_details_subclasses(subclass)) + return subclasses + backend_name_to_subclass = { - subclass.__name__: subclass for subclass in BackendDetails.__subclasses__() + subclass.__name__: subclass + for subclass in _get_all_final_backend_details_subclasses(BackendDetails) } + if backend_id not in backend_name_to_subclass: raise NotImplementedError(f"Backend {backend_id} was not found.")
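
Note on the backend_api.py hunk: before this change every concrete backend subclassed
BackendDetails directly, so BackendDetails.__subclasses__() found them all. With
CudaBackend/MetalBackend now sitting one level down under AotiBackend, discovery must
recurse to the leaf classes, which is what _get_all_final_backend_details_subclasses
does. A minimal sketch of what a new AOTInductor-driven backend looks like under this
scheme (the device string and the single compile option shown are illustrative
assumptions, not prescribed by this diff):

    import typing
    from typing import Any, Dict, final

    from executorch.backends.aoti.aoti_backend import AotiBackend


    @final
    class MyDeviceBackend(AotiBackend):
        @staticmethod
        def get_device_name() -> str:
            return "cuda"  # device string handed to move_to_device_pass

        @staticmethod
        def get_supported_fallback_kernels() -> Dict[str, Any]:
            return {}  # fallback kernels this backend's runtime can service

        @staticmethod
        def get_decomposition_table() -> Dict[Any, Any]:
            return {}  # extra decompositions applied before aot_compile

        @staticmethod
        def get_aoti_compile_options() -> Dict[str, typing.Any]:
            return {
                # Keep the compiled artifact libtorch-free
                "aot_inductor.link_libtorch": False,
            }

Because discovery keys on subclass.__name__, lowering to this backend would be
requested with the backend id "MyDeviceBackend".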