Skip to content

Commit 07c6c87

Browse files
committed
add support fallback kernels check
1 parent 1a4f448 commit 07c6c87

File tree

6 files changed

+220
-171
lines changed

6 files changed

+220
-171
lines changed

backends/aoti/aoti_backend.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import contextlib
78
import copy
89
import os
910
import shutil
1011
import typing
1112

1213
from subprocess import check_call
13-
from typing import final, List
14+
from typing import Any, Dict, final, List, Optional, Set
1415

1516
import torch
1617
from executorch.exir.backend.backend_details import (
@@ -19,6 +20,48 @@
1920
PreprocessResult,
2021
)
2122
from executorch.exir.backend.compile_spec_schema import CompileSpec
23+
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
24+
25+
26+
# exist fallback operators in et namespace;
27+
supported_fallback_kernels: Dict[str, Any] = {}
28+
29+
# required fallback kernels but not supported
30+
missing_fallback_kernels: Set[str] = set()
31+
32+
33+
# context manager for non-fallback guarantee
34+
# it will raise exception when generating fallback kernels during aoti compile
35+
@contextlib.contextmanager
36+
def raise_on_generate_fall_back_call():
37+
original_generate_c_shim_extern_kernel_call = (
38+
CppWrapperCpu.generate_c_shim_extern_kernel_call
39+
)
40+
41+
def generate_supported_c_shim_extern_kernel_call(
42+
self,
43+
kernel: str,
44+
args: list[str],
45+
device: str,
46+
*,
47+
debug_args: Optional[list[str]] = None,
48+
):
49+
if kernel in supported_fallback_kernels:
50+
original_generate_c_shim_extern_kernel_call(
51+
self, kernel, args, device, debug_args=debug_args
52+
)
53+
else:
54+
missing_fallback_kernels.add(kernel)
55+
56+
CppWrapperCpu.generate_c_shim_extern_kernel_call = (
57+
generate_supported_c_shim_extern_kernel_call
58+
)
59+
try:
60+
yield
61+
finally:
62+
CppWrapperCpu.generate_c_shim_extern_kernel_call = (
63+
original_generate_c_shim_extern_kernel_call
64+
)
2265

2366

2467
@final
@@ -50,7 +93,14 @@ def preprocess(
5093
"max_autotune_conv_backends": "TRITON",
5194
}
5295

53-
so_path = torch._inductor.aot_compile(edge_program_module, args, kwargs, options=options) # type: ignore[arg-type]
96+
with raise_on_generate_fall_back_call():
97+
so_path = torch._inductor.aot_compile(edge_program_module, args, kwargs, options=options) # type: ignore[arg-type]
98+
if len(missing_fallback_kernels) > 0:
99+
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
100+
raise RuntimeError(
101+
f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n"
102+
"Please add them to the AOTI backend."
103+
)
54104

55105
assert so_path == output_path, f"Expected {output_path} but got {so_path}"
56106

0 commit comments

Comments
 (0)