|
4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
| 7 | +import contextlib |
7 | 8 | import copy |
8 | 9 | import os |
9 | 10 | import shutil |
10 | 11 | import typing |
11 | 12 |
|
12 | 13 | from subprocess import check_call |
13 | | -from typing import final, List |
| 14 | +from typing import Any, Dict, final, List, Optional, Set |
14 | 15 |
|
15 | 16 | import torch |
16 | 17 | from executorch.exir.backend.backend_details import ( |
|
19 | 20 | PreprocessResult, |
20 | 21 | ) |
21 | 22 | from executorch.exir.backend.compile_spec_schema import CompileSpec |
| 23 | +from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu |
| 24 | + |
| 25 | + |
| 26 | +# exist fallback operators in et namespace; |
| 27 | +supported_fallback_kernels: Dict[str, Any] = {} |
| 28 | + |
| 29 | +# required fallback kernels but not supported |
| 30 | +missing_fallback_kernels: Set[str] = set() |
| 31 | + |
| 32 | + |
| 33 | +# context manager for non-fallback guarantee |
| 34 | +# it will raise exception when generating fallback kernels during aoti compile |
| 35 | +@contextlib.contextmanager |
| 36 | +def raise_on_generate_fall_back_call(): |
| 37 | + original_generate_c_shim_extern_kernel_call = ( |
| 38 | + CppWrapperCpu.generate_c_shim_extern_kernel_call |
| 39 | + ) |
| 40 | + |
| 41 | + def generate_supported_c_shim_extern_kernel_call( |
| 42 | + self, |
| 43 | + kernel: str, |
| 44 | + args: list[str], |
| 45 | + device: str, |
| 46 | + *, |
| 47 | + debug_args: Optional[list[str]] = None, |
| 48 | + ): |
| 49 | + if kernel in supported_fallback_kernels: |
| 50 | + original_generate_c_shim_extern_kernel_call( |
| 51 | + self, kernel, args, device, debug_args=debug_args |
| 52 | + ) |
| 53 | + else: |
| 54 | + missing_fallback_kernels.add(kernel) |
| 55 | + |
| 56 | + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( |
| 57 | + generate_supported_c_shim_extern_kernel_call |
| 58 | + ) |
| 59 | + try: |
| 60 | + yield |
| 61 | + finally: |
| 62 | + CppWrapperCpu.generate_c_shim_extern_kernel_call = ( |
| 63 | + original_generate_c_shim_extern_kernel_call |
| 64 | + ) |
22 | 65 |
|
23 | 66 |
|
24 | 67 | @final |
@@ -50,7 +93,14 @@ def preprocess( |
50 | 93 | "max_autotune_conv_backends": "TRITON", |
51 | 94 | } |
52 | 95 |
|
53 | | - so_path = torch._inductor.aot_compile(edge_program_module, args, kwargs, options=options) # type: ignore[arg-type] |
| 96 | + with raise_on_generate_fall_back_call(): |
| 97 | + so_path = torch._inductor.aot_compile(edge_program_module, args, kwargs, options=options) # type: ignore[arg-type] |
| 98 | + if len(missing_fallback_kernels) > 0: |
| 99 | + formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels)) |
| 100 | + raise RuntimeError( |
| 101 | + f"Missing fallback kernels ({len(missing_fallback_kernels)} total):\n - {formatted_kernels}\n" |
| 102 | + "Please add them to the AOTI backend." |
| 103 | + ) |
54 | 104 |
|
55 | 105 | assert so_path == output_path, f"Expected {output_path} but got {so_path}" |
56 | 106 |
|
|
0 commit comments