
Commit 31d7508

[aoti-backend-consolidation 2/3] backend.py
Pull Request resolved: #15430

# Summary

This diff consolidates the backend functionality into a single target `//executorch/backends/aoti:aoti_backend` and simplifies the CUDA backend target by making it depend on the consolidated backend target.

The following changes are made in this diff:

* Creation of a new target `//executorch/backends/aoti:aoti_backend` in `fbcode/executorch/backends/aoti/targets.bzl`, which includes the necessary dependencies for the AOTI backend.
* Update of the `//executorch/backends/cuda:cuda_backend` target in `fbcode/executorch/backends/cuda/TARGETS` to depend on the new `//executorch/backends/aoti:aoti_backend` target instead of individual AOTI backend dependencies.
* Creation of a new file `fbcode/executorch/backends/aoti/aoti_backend.py`, which imports the necessary dependencies and passes for the AOTI backend.
* Simplification of the `xplat/executorch/backends/cuda/cuda_backend.py` file by removing unnecessary imports and using the new `AotiBackend` class from `aoti_backend.py`.

ghstack-source-id: 319365324
Differential Revision: [D85704977](https://our.internmc.facebook.com/intern/diff/D85704977/)
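To illustrate the consolidation described above, here is a minimal, hypothetical sketch of how a device backend such as the CUDA one can subclass `AotiBackend` instead of wiring up the AOTInductor plumbing itself. The class name and the empty kernel/option maps are placeholders for illustration, not the actual contents of `cuda_backend.py`.

# Hypothetical sketch only: the class name and the empty kernel/option maps
# below are placeholders, not the real cuda_backend.py contents.
from typing import Any, Dict

from executorch.backends.aoti.aoti_backend import AotiBackend


class CudaBackend(AotiBackend):
    """Example device backend built on the consolidated AotiBackend base class."""

    @staticmethod
    def get_device_name() -> str:
        # Device string handed to move_to_device_pass and AOTInductor.
        return "cuda"

    @staticmethod
    def get_supported_fallback_kernels() -> Dict[str, Any]:
        # Fallback kernels the runtime can service; anything outside this set
        # is reported by collect_unsupported_fallback_kernels. Left empty here.
        return {}

    @staticmethod
    def get_decomposition_table() -> Dict[Any, Any]:
        # Extra decompositions to run before AOT compilation; none in this sketch.
        return {}

    @staticmethod
    def get_aoti_compile_options() -> Dict[str, Any]:
        # Options forwarded to torch._inductor.aot_compile; left empty in this sketch.
        return {}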
1 parent c56a824 commit 31d7508

File tree

5 files changed: +322 -373 lines

backends/aoti/aoti_backend.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import os
import typing
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Set

import torch
from executorch.backends.aoti.passes.replace_view_copy_with_view import (
    ReplaceViewCopyWithViewPass,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir._warnings import experimental
from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
from torch.export.passes import move_to_device_pass


class COMPILE_SPEC_KEYS(Enum):
    METHOD_NAME = "method_name"


@experimental(
    "This API and all of aoti-driven backend related functionality are experimental."
)
class AotiBackend(BackendDetails, ABC):
    """
    Base backend class for AOTInductor-based backends.

    This class provides common functionality for compiling models using AOTInductor
    with different device targets (CUDA, Metal/MPS, etc.).
    """

    @staticmethod
    @abstractmethod
    def get_device_name() -> str:
        """Return the device name for this backend (e.g., 'cuda', 'mps')."""
        pass

    @staticmethod
    @abstractmethod
    def get_supported_fallback_kernels() -> Dict[str, Any]:
        """Return the set of supported fallback kernels for this backend."""
        pass

    @staticmethod
    @abstractmethod
    def get_decomposition_table() -> Dict[Any, Any]:
        """Return the decomposition table for this backend."""
        pass

    @staticmethod
    @abstractmethod
    def get_aoti_compile_options() -> Dict[str, typing.Any]:
        """Return the AOTInductor compilation options for this backend."""
        pass

    @classmethod
    @contextlib.contextmanager
    def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str]):
        """
        Context manager to collect unsupported fallback kernels during compilation.
        Monitors both extern kernel calls and runtime lookup.
        """
        supported_kernels = cls.get_supported_fallback_kernels()

        original_generate_c_shim_extern_kernel_call = (
            CppWrapperCpu.generate_c_shim_extern_kernel_call
        )
        original_generate_fallback_kernel_with_runtime_lookup_aot = (
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot
        )

        def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels(
            self,
            kernel: str,
            args: list[str],
            device: str,
            *,
            debug_args: Optional[list[str]] = None,
            debug_handle: Optional[int] = None,
        ):
            if kernel not in supported_kernels:
                missing_fallback_kernels.add(kernel)

            original_generate_c_shim_extern_kernel_call(
                self,
                kernel,
                args,
                device,
                debug_args=debug_args,
                debug_handle=debug_handle,
            )

        def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels(
            self,
            op_overload,
            raw_args,
            output_args,
            raw_outputs,
        ):
            kernel_name = getattr(op_overload, "_name", str(op_overload))
            if kernel_name not in supported_kernels:
                missing_fallback_kernels.add(kernel_name)

            original_generate_fallback_kernel_with_runtime_lookup_aot(
                self, op_overload, raw_args, output_args, raw_outputs
            )

        CppWrapperCpu.generate_c_shim_extern_kernel_call = (
            generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels
        )
        CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels

        try:
            yield
        finally:
            CppWrapperCpu.generate_c_shim_extern_kernel_call = (
                original_generate_c_shim_extern_kernel_call
            )
            CppWrapperCpu.generate_fallback_kernel_with_runtime_lookup_aot = (
                original_generate_fallback_kernel_with_runtime_lookup_aot
            )

    @classmethod
    def preprocess(
        cls,
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        """
        Preprocess the edge program and compile it using AOTInductor.
        Weights are always separated from the SO file.
        """
        device_name = cls.get_device_name()
        decomposition_table = cls.get_decomposition_table()
        options = cls.get_aoti_compile_options()

        # Move the edge_program to the target device
        device_edge_program = move_to_device_pass(edge_program, device_name)

        # Replace view_copy with view
        ReplaceViewCopyWithViewPass()(device_edge_program.graph_module)

        # Run decompositions if any
        if decomposition_table:
            device_edge_program = device_edge_program.run_decompositions(
                decomposition_table
            )

        edge_program_module = device_edge_program.module()

        # Grab all input placeholders from the graph
        user_input_names = device_edge_program.graph_signature.user_inputs
        user_input_placeholders = []
        for node in device_edge_program.graph.nodes:
            if node.op == "placeholder" and node.name in user_input_names:
                user_input_placeholders.append(node.meta["val"])

        # Track missing fallback kernels
        missing_fallback_kernels: Set[str] = set()

        # Compile with fallback kernel collection
        with cls.collect_unsupported_fallback_kernels(
            missing_fallback_kernels
        ), torch.no_grad():
            paths = torch._inductor.aot_compile(
                edge_program_module, tuple(user_input_placeholders), options=options
            )

        if len(missing_fallback_kernels) > 0:
            formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
            method_name = cls.method_name_from_compile_specs(compile_specs)
            raise RuntimeError(
                f"Method {method_name} missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
                "Please add them to the AOTI backend."
            )

        # Extract paths - weights are always separated
        so_path = None
        blob_path = None

        if isinstance(paths, list):
            for path in paths:
                if path.endswith(".wrapper.so"):
                    so_path = path
                elif path.endswith(".wrapper_weights.blob"):
                    blob_path = path
        else:
            so_path = paths

        if so_path is None or blob_path is None:
            raise RuntimeError(
                f"Could not find required files in compiled paths, got {paths}"
            )

        # Read SO file
        with open(so_path, "rb") as f:
            so_data = f.read()

        # Read weights blob
        with open(blob_path, "rb") as f:
            blob_data = f.read()

        # Create named data store
        named_data_store = NamedDataStore()
        method_name = cls.method_name_from_compile_specs(compile_specs)

        # Add SO and weights blob separately
        named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, None)
        weights_blob_data_type = f"aoti_{device_name}_blob"
        named_data_store.add_named_data(
            method_name + "_weights_blob", blob_data, 1, weights_blob_data_type
        )

        # Clean up the generated files
        os.remove(so_path)
        os.remove(blob_path)

        return PreprocessResult(
            processed_bytes=b"",
            debug_handle_map={},
            data_store_output=named_data_store.get_named_data_store_output(),
        )

    @staticmethod
    def generate_method_name_compile_spec(
        method_name: str,
    ) -> CompileSpec:
        """
        Generate a CompileSpec for the given method name.
        """
        return CompileSpec(
            COMPILE_SPEC_KEYS.METHOD_NAME.value,
            method_name.encode("utf-8"),
        )

    @staticmethod
    def method_name_from_compile_specs(
        compile_specs: List[CompileSpec],
    ) -> str:
        """
        Extract the method name from the compile specs.
        """
        for spec in compile_specs:
            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
                return spec.value.decode("utf-8")
        raise RuntimeError(
            f"Could not find method name in compile specs: {compile_specs}"
        )

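A quick sketch of the compile-spec helpers defined in the file above: the caller records the method name when building compile specs, and preprocess later recovers it from them. The method name "forward" is just an example value.

# Minimal sketch of the compile-spec round trip provided by AotiBackend.
from executorch.backends.aoti.aoti_backend import AotiBackend

# Producer side: attach the method name to the compile specs handed to the backend.
compile_specs = [AotiBackend.generate_method_name_compile_spec("forward")]

# Consumer side (e.g., inside preprocess): recover the same name from the specs.
assert AotiBackend.method_name_from_compile_specs(compile_specs) == "forward"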
backends/aoti/targets.bzl

Lines changed: 17 additions & 0 deletions
@@ -16,6 +16,23 @@ def define_common_targets():
        ],
    )

    runtime.python_library(
        name = "aoti_backend",
        srcs = [
            "aoti_backend.py",
        ],
        visibility = [
            "//executorch/...",
        ],
        deps = [
            "//caffe2:torch",
            "//executorch/backends/aoti/passes:passes",
            "//executorch/exir/_serialize:lib",
            "//executorch/exir/backend:backend_details",
            "//executorch/exir/backend:compile_spec_schema",
        ],
    )

    # AOTI common shims functionality
    runtime.cxx_library(
        name = "common_shims",
