Skip to content

Commit 1e90fd0

Browse files
committed
merge latest update
1 parent 4be608d commit 1e90fd0

File tree

3 files changed

+104
-30
lines changed

3 files changed

+104
-30
lines changed

backends/aoti/aoti_backend.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,30 +42,38 @@ class AotiBackend(ABC):
4242
BackendDetails and AotiBackend to get the full functionality.
4343
"""
4444

45-
@staticmethod
45+
@classmethod
4646
@abstractmethod
47-
def get_device_name() -> str:
47+
def get_device_name(cls) -> str:
4848
"""Return the device name for this backend (e.g., 'cuda', 'metal')."""
4949
pass
5050

51-
@staticmethod
51+
@classmethod
5252
@abstractmethod
53-
def get_supported_fallback_kernels() -> Dict[str, Any]:
53+
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
5454
"""Return the set of supported fallback kernels for this backend."""
5555
pass
5656

57-
@staticmethod
57+
@classmethod
5858
@abstractmethod
59-
def get_decomposition_table() -> Dict[Any, Any]:
59+
def get_decomposition_table(cls) -> Dict[Any, Any]:
6060
"""Return the decomposition table for this backend."""
6161
pass
6262

63-
@staticmethod
63+
@classmethod
6464
@abstractmethod
65-
def get_aoti_compile_options() -> Dict[str, typing.Any]:
65+
def get_aoti_compile_options(
66+
cls, compile_specs: List[CompileSpec]
67+
) -> Dict[str, typing.Any]:
6668
"""Return the AOTInductor compilation options for this backend."""
6769
pass
6870

71+
@classmethod
72+
@abstractmethod
73+
def get_custom_passes(cls) -> List[typing.Any]:
74+
"""Return the list of custom passes to apply after ReplaceViewCopyWithViewPass and before decomposition."""
75+
pass
76+
6977
@classmethod
7078
@contextlib.contextmanager
7179
def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
@@ -145,7 +153,7 @@ def preprocess(
145153
"""
146154
device_name = cls.get_device_name()
147155
decomposition_table = cls.get_decomposition_table()
148-
options = cls.get_aoti_compile_options()
156+
options = cls.get_aoti_compile_options(compile_specs)
149157

150158
# Move the edge_program to the target device
151159
device_edge_program = move_to_device_pass(
@@ -155,6 +163,11 @@ def preprocess(
155163
# Replace view_copy with view
156164
ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)
157165

166+
# Apply custom backend-specific passes
167+
custom_passes = cls.get_custom_passes()
168+
for custom_pass in custom_passes:
169+
custom_pass(device_edge_program.graph_module)
170+
158171
# Run decompositions if any
159172
if decomposition_table:
160173
device_edge_program = device_edge_program.run_decompositions(
@@ -236,8 +249,9 @@ def preprocess(
236249
data_store_output=named_data_store.get_named_data_store_output(),
237250
)
238251

239-
@staticmethod
252+
@classmethod
240253
def generate_method_name_compile_spec(
254+
cls,
241255
method_name: str,
242256
) -> CompileSpec:
243257
"""
@@ -248,8 +262,9 @@ def generate_method_name_compile_spec(
248262
method_name.encode("utf-8"),
249263
)
250264

251-
@staticmethod
265+
@classmethod
252266
def method_name_from_compile_specs(
267+
cls,
253268
compile_specs: List[CompileSpec],
254269
) -> str:
255270
"""

backends/apple/metal/metal_backend.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import typing
8-
from typing import Any, Dict, final
8+
from typing import Any, Dict, final, List
99

1010
from executorch.backends.aoti.aoti_backend import AotiBackend
1111
from executorch.exir._warnings import experimental
1212
from executorch.exir.backend.backend_details import BackendDetails
13+
from executorch.exir.backend.compile_spec_schema import CompileSpec
1314

1415

1516
@final
@@ -23,25 +24,34 @@ class MetalBackend(AotiBackend, BackendDetails):
2324
using the Executorch runtime.
2425
"""
2526

26-
@staticmethod
27-
def get_device_name() -> str:
27+
@classmethod
28+
def get_device_name(cls) -> str:
2829
return "metal"
2930

30-
@staticmethod
31-
def get_supported_fallback_kernels() -> Dict[str, Any]:
31+
@classmethod
32+
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
3233
return {
3334
"aoti_torch_mps_addmm_out": None,
3435
"aoti_torch_mps_convolution": None,
3536
"aoti_torch_mps_mm_out": None,
3637
"at::_ops::_scaled_dot_product_attention_math_for_mps::call": None,
3738
}
3839

39-
@staticmethod
40-
def get_decomposition_table() -> Dict[Any, Any]:
40+
@classmethod
41+
def get_decomposition_table(cls) -> Dict[Any, Any]:
4142
return {}
4243

43-
@staticmethod
44-
def get_aoti_compile_options() -> Dict[str, typing.Any]:
44+
@classmethod
45+
def get_custom_passes(cls) -> List[typing.Any]:
46+
"""Return Metal-specific passes (currently none)"""
47+
return []
48+
49+
@classmethod
50+
def get_aoti_compile_options(
51+
cls, compile_specs: List[CompileSpec]
52+
) -> Dict[str, typing.Any]:
53+
"""Get AOTI compile options for Metal backend."""
54+
_ = compile_specs # Unused, but required by interface
4555
return {
4656
# Do not link against the full PyTorch/libtorch library
4757
"aot_inductor.link_libtorch": False,

backends/cuda/cuda_backend.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import typing
8-
from typing import Any, Dict, final
8+
from importlib import resources
9+
from typing import Any, Dict, final, List
910

1011
import torch
1112
from executorch.backends.aoti.aoti_backend import AotiBackend
13+
from executorch.backends.cuda.triton.replacement_pass import (
14+
ReplaceEdgeOpWithTritonOpPass,
15+
)
1216
from executorch.exir._warnings import experimental
1317
from executorch.exir.backend.backend_details import BackendDetails
18+
from executorch.exir.backend.compile_spec_schema import CompileSpec
1419
from torch._inductor.decomposition import conv1d_to_conv2d
1520

1621

@@ -25,25 +30,37 @@ class CudaBackend(AotiBackend, BackendDetails):
2530
using the Executorch runtime.
2631
"""
2732

28-
@staticmethod
29-
def get_device_name() -> str:
33+
@classmethod
34+
def get_device_name(cls) -> str:
3035
return "cuda"
3136

32-
@staticmethod
33-
def get_supported_fallback_kernels() -> Dict[str, Any]:
37+
@classmethod
38+
def get_supported_fallback_kernels(cls) -> Dict[str, Any]:
3439
return {
3540
"at::_ops::_weight_int4pack_mm::call": None,
3641
}
3742

38-
@staticmethod
39-
def get_decomposition_table() -> Dict[Any, Any]:
43+
@classmethod
44+
def get_decomposition_table(cls) -> Dict[Any, Any]:
4045
return {
4146
torch.ops.aten.conv1d.default: conv1d_to_conv2d,
4247
}
4348

44-
@staticmethod
45-
def get_aoti_compile_options() -> Dict[str, typing.Any]:
46-
return {
49+
@classmethod
50+
def get_custom_passes(cls) -> List[typing.Any]:
51+
"""Return CUDA-specific passes: ReplaceEdgeOpWithTritonOpPass"""
52+
return [ReplaceEdgeOpWithTritonOpPass()]
53+
54+
@classmethod
55+
def get_aoti_compile_options(
56+
cls, compile_specs: List[CompileSpec]
57+
) -> Dict[str, typing.Any]:
58+
"""
59+
Get AOTI compile options for CUDA backend.
60+
Options may vary based on platform (Linux vs Windows).
61+
"""
62+
# Base options for all platforms
63+
options: Dict[str, typing.Any] = {
4764
# Disable this to support sdpa decomposition
4865
# TODO(gasoonjia): remove it after pin bump to latest pytorch
4966
"loop_ordering_after_fusion": False,
@@ -65,3 +82,35 @@ def get_aoti_compile_options() -> Dict[str, typing.Any]:
6582
# Use TRITON backend for convolution operations tuning only to avoid using operators in libtorch
6683
"max_autotune_conv_backends": "TRITON",
6784
}
85+
86+
# Parse compile_specs to check for platform
87+
platform = "linux"
88+
shim_library_path = None
89+
for spec in compile_specs:
90+
if spec.key == "platform":
91+
platform = spec.value.decode("utf-8")
92+
if spec.key == "shim_library_path":
93+
shim_library_path = spec.value.decode("utf-8")
94+
95+
# Add platform-specific options
96+
if platform == "windows":
97+
# For Windows, get default shim library path if not provided
98+
if shim_library_path is None:
99+
lib_dir = resources.files("executorch").joinpath("data/lib")
100+
shim_library_path = str(lib_dir)
101+
102+
options.update(
103+
{
104+
"aot_inductor.cross_target_platform": "windows",
105+
"aot_inductor.aoti_shim_library": "aoti_cuda_shims",
106+
"aot_inductor.aoti_shim_library_path": shim_library_path,
107+
"aot_inductor.precompile_headers": False,
108+
}
109+
)
110+
else:
111+
# Linux platform
112+
assert (
113+
shim_library_path is None
114+
), "shim_library_path should not be set for Linux"
115+
116+
return options

0 commit comments

Comments
 (0)