Skip to content

Commit 90fc8c0

Browse files
committed
make aoti_backend not a real backend
1 parent c1e6a29 commit 90fc8c0

File tree

4 files changed

+13
-23
lines changed

4 files changed

+13
-23
lines changed

backends/aoti/aoti_backend.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,7 @@
1717
)
1818
from executorch.exir._serialize._named_data_store import NamedDataStore
1919
from executorch.exir._warnings import experimental
20-
from executorch.exir.backend.backend_details import (
21-
BackendDetails,
22-
ExportedProgram,
23-
PreprocessResult,
24-
)
20+
from executorch.exir.backend.backend_details import ExportedProgram, PreprocessResult
2521
from executorch.exir.backend.compile_spec_schema import CompileSpec
2622
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
2723
from torch.export.passes import move_to_device_pass
@@ -34,12 +30,16 @@ class COMPILE_SPEC_KEYS(Enum):
3430
@experimental(
3531
"This API and all of aoti-driven backend related functionality are experimental."
3632
)
37-
class AotiBackend(BackendDetails, ABC):
33+
class AotiBackend(ABC):
3834
"""
39-
Base backend class for AOTInductor-based backends.
35+
Base mixin class for AOTInductor-based backends.
4036
4137
This class provides common functionality for compiling models using AOTInductor
4238
with different device targets (CUDA, Metal, etc.).
39+
40+
This is a mixin class, not an actual backend object, for aoti-driven backens.
41+
Concrete backends (e.g., CudaBackend, MetalBackend) should inherit from both
42+
BackendDetails and AotiBackend to get the full functionality.
4343
"""
4444

4545
@staticmethod

backends/apple/metal/metal_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
from typing import Any, Dict, final
99

1010
from executorch.backends.aoti.aoti_backend import AotiBackend
11+
from executorch.exir.backend.backend_details import BackendDetails
1112
from executorch.exir._warnings import experimental
1213

1314

1415
@final
1516
@experimental(
1617
"This API and all of Metal backend related functionality are experimental."
1718
)
18-
class MetalBackend(AotiBackend):
19+
class MetalBackend(BackendDetails, AotiBackend):
1920
"""
2021
MetalBackend is a backend that compiles a model to run on Metal/MPS devices. It uses the AOTInductor compiler to generate
2122
optimized Metal kernels for the model's operators with libtorch-free. The compiled model can be executed on Metal devices

backends/cuda/cuda_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
from executorch.backends.aoti.aoti_backend import AotiBackend
12+
from executorch.exir.backend.backend_details import BackendDetails
1213
from executorch.exir._warnings import experimental
1314
from torch._inductor.decomposition import conv1d_to_conv2d
1415

@@ -17,7 +18,7 @@
1718
@experimental(
1819
"This API and all of cuda backend related functionality are experimental."
1920
)
20-
class CudaBackend(AotiBackend):
21+
class CudaBackend(BackendDetails, AotiBackend):
2122
"""
2223
CudaBackend is a backend that compiles a model to run on CUDA devices. It uses the AOTInductor compiler to generate
2324
optimized CUDA kernels for the model's operators with libtorch-free. The compiled model can be executed on CUDA devices

exir/backend/backend_api.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from contextlib import contextmanager, nullcontext
1111
from dataclasses import dataclass
1212
from functools import singledispatch
13-
from typing import Dict, Generator, List, Mapping, Set
13+
from typing import Dict, Generator, List, Mapping
1414

1515
import torch
1616

@@ -581,21 +581,9 @@ def lower_all_submodules_to_backend(
581581
for method_name, call_submodule_nodes in method_to_submodules_nodes.items()
582582
}
583583

584-
def _get_all_final_backend_details_subclasses(cls) -> Set[type]:
585-
subclasses = set()
586-
if len(cls.__subclasses__()) == 0:
587-
return {cls}
588-
else:
589-
for subclass in cls.__subclasses__():
590-
# Recursively check subclasses
591-
subclasses.update(_get_all_final_backend_details_subclasses(subclass))
592-
return subclasses
593-
594584
backend_name_to_subclass = {
595-
subclass.__name__: subclass
596-
for subclass in _get_all_final_backend_details_subclasses(BackendDetails)
585+
subclass.__name__: subclass for subclass in BackendDetails.__subclasses__()
597586
}
598-
599587
if backend_id not in backend_name_to_subclass:
600588
raise NotImplementedError(f"Backend {backend_id} was not found.")
601589

0 commit comments

Comments
 (0)