77import contextlib
88import os
99import typing
10+ from enum import Enum
1011
1112from typing import Any , Dict , final , List , Optional , Set
1213
1314import torch
15+ from executorch .backends .cuda .replace_slice_copy_with_slice import ReplaceSliceCopyWithSlicePass
1416from executorch .exir ._serialize ._named_data_store import NamedDataStore
1517from executorch .exir ._warnings import experimental
1618from executorch .exir .backend .backend_details import (
2123from executorch .exir .backend .compile_spec_schema import CompileSpec
2224from torch ._inductor .codegen .cpp_wrapper_cpu import CppWrapperCpu
2325from torch .export .passes import move_to_device_pass
24-
26+ from torch . nn . attention import SDPBackend
2527
# Fallback operators in the `et` namespace that already exist and are supported.
supported_fallback_kernels: Dict[str, Any] = {}

# Fallback kernels that compilation required but that are not supported;
# presumably filled in while AOTI compiles the graph — confirm against the
# collect_unsupported_fallback_kernels context manager below.
missing_fallback_kernels: Set[str] = set()
3133
class COMPILE_SPEC_KEYS(Enum):
    """Keys under which CUDA-backend data is stored in CompileSpec entries."""

    # Name of the ExecuTorch method the compiled blob belongs to.
    METHOD_NAME = "method_name"
3236
# Context manager providing a no-fallback guarantee:
# it raises an exception if fallback kernels are generated during AOTI compile.
@@ -108,6 +112,9 @@ def preprocess(
108112 # Move the edge_program from CPU to CUDA for aoti compile
109113 cuda_edge_program = move_to_device_pass (edge_program , "cuda" )
110114
115+ # replace slice_copy with slice
116+ ReplaceSliceCopyWithSlicePass ()(cuda_edge_program .graph_module )
117+
111118 edge_program_module = cuda_edge_program .module ()
112119
113120 # Grab all input placeholders from the graph
@@ -132,7 +139,8 @@ def preprocess(
132139 "max_autotune_conv_backends" : "TRITON" ,
133140 }
134141
135- with collect_unsupported_fallback_kernels ():
142+ with collect_unsupported_fallback_kernels (), torch .nn .attention .sdpa_kernel ([SDPBackend .MATH ]), torch .no_grad ():
143+ torch ._logging .set_logs (post_grad_graphs = True )
136144 so_path = torch ._inductor .aot_compile (edge_program_module , tuple (user_input_placeholders ), options = options ) # type: ignore[arg-type]
137145 if len (missing_fallback_kernels ) > 0 :
138146 formatted_kernels = "\n - " .join (sorted (missing_fallback_kernels ))
@@ -146,7 +154,8 @@ def preprocess(
146154 so_data = f .read ()
147155
148156 named_data_store = NamedDataStore ()
149- named_data_store .add_named_data ("so_blob" , so_data , 1 , "aoti_cuda_blob" )
157+ method_name = CudaBackend .method_name_from_compile_specs (compile_specs )
158+ named_data_store .add_named_data (method_name + "_so_blob" , so_data , 1 , "aoti_cuda_blob" )
150159
# Clean up the generated so file; it has been packaged into the NamedDataStore
152161 # pyre-ignorep[6]: Incompatible parameter type
@@ -157,3 +166,30 @@ def preprocess(
157166 debug_handle_map = {},
158167 data_store_output = named_data_store .get_named_data_store_output (),
159168 )
169+
170+ @staticmethod
171+ def generate_method_name_compile_spec (
172+ method_name : str ,
173+ ) -> CompileSpec :
174+ """
175+ Returns the compile spec representing the model compute precision, for additional details
176+ please refer to the documentation for ``coremltools.precision``.
177+ """
178+ return CompileSpec (
179+ COMPILE_SPEC_KEYS .METHOD_NAME .value ,
180+ method_name .encode ("utf-8" ),
181+ )
182+
183+ @staticmethod
184+ def method_name_from_compile_specs (
185+ compile_specs : List [CompileSpec ],
186+ ) -> str :
187+ """
188+ Returns the method name from the compile specs.
189+ """
190+ for spec in compile_specs :
191+ if spec .key == COMPILE_SPEC_KEYS .METHOD_NAME .value :
192+ return spec .value .decode ("utf-8" )
193+ raise RuntimeError (
194+ f"Could not find method name in compile specs: { compile_specs } "
195+ )
0 commit comments