77import  contextlib 
88import  os 
99import  typing 
10+ from  enum  import  Enum 
1011
1112from  typing  import  Any , Dict , final , List , Optional , Set 
1213
1314import  torch 
15+ from  executorch .backends .cuda .replace_slice_copy_with_slice  import  ReplaceSliceCopyWithSlicePass 
1416from  executorch .exir ._serialize ._named_data_store  import  NamedDataStore 
1517from  executorch .exir ._warnings  import  experimental 
1618from  executorch .exir .backend .backend_details  import  (
2123from  executorch .exir .backend .compile_spec_schema  import  CompileSpec 
2224from  torch ._inductor .codegen .cpp_wrapper_cpu  import  CppWrapperCpu 
2325from  torch .export .passes  import  move_to_device_pass 
24- 
26+ from   torch . nn . attention   import   SDPBackend 
2527
2628# fallback operators that exist in the et namespace
2729supported_fallback_kernels : Dict [str , Any ] =  {}
2830
2931# fallback kernels that are required but not supported
3032missing_fallback_kernels : Set [str ] =  set ()
3133
class COMPILE_SPEC_KEYS(Enum):
    """Keys used for the ``CompileSpec`` entries built and read in this file."""

    # Key of the compile spec whose value holds the UTF-8 encoded method name.
    METHOD_NAME = "method_name"
3236
3337# context manager for non-fallback guarantee 
3438# it will raise exception when generating fallback kernels during aoti compile 
@@ -108,6 +112,9 @@ def preprocess(
108112        # Move the edge_program from CPU to CUDA for aoti compile 
109113        cuda_edge_program  =  move_to_device_pass (edge_program , "cuda" )
110114
115+         # replace slice_copy with slice 
116+         ReplaceSliceCopyWithSlicePass ()(cuda_edge_program .graph_module )
117+ 
111118        edge_program_module  =  cuda_edge_program .module ()
112119
113120        # Grab all input placeholders from the graph 
@@ -132,7 +139,8 @@ def preprocess(
132139            "max_autotune_conv_backends" : "TRITON" ,
133140        }
134141
135-         with  collect_unsupported_fallback_kernels ():
142+         with  collect_unsupported_fallback_kernels (), torch .nn .attention .sdpa_kernel ([SDPBackend .MATH ]), torch .no_grad ():
143+             torch ._logging .set_logs (post_grad_graphs = True )
136144            so_path  =  torch ._inductor .aot_compile (edge_program_module , tuple (user_input_placeholders ), options = options )  # type: ignore[arg-type] 
137145            if  len (missing_fallback_kernels ) >  0 :
138146                formatted_kernels  =  "\n   - " .join (sorted (missing_fallback_kernels ))
@@ -146,7 +154,8 @@ def preprocess(
146154            so_data  =  f .read ()
147155
148156        named_data_store  =  NamedDataStore ()
149-         named_data_store .add_named_data ("so_blob" , so_data , 1 , "aoti_cuda_blob" )
157+         method_name  =  CudaBackend .method_name_from_compile_specs (compile_specs )
158+         named_data_store .add_named_data (method_name  +  "_so_blob" , so_data , 1 , "aoti_cuda_blob" )
150159
151160        # Clean up the generated so file; it has been packaged into the NamedDataStore
152161        # pyre-ignorep[6]: Incompatible parameter type 
@@ -157,3 +166,30 @@ def preprocess(
157166            debug_handle_map = {},
158167            data_store_output = named_data_store .get_named_data_store_output (),
159168        )
169+ 
170+     @staticmethod  
171+     def  generate_method_name_compile_spec (
172+         method_name : str ,
173+     ) ->  CompileSpec :
174+         """ 
175+         Returns the compile spec representing the model compute precision, for additional details 
176+         please refer to the documentation for ``coremltools.precision``. 
177+         """ 
178+         return  CompileSpec (
179+             COMPILE_SPEC_KEYS .METHOD_NAME .value ,
180+             method_name .encode ("utf-8" ),
181+         )
182+ 
183+     @staticmethod  
184+     def  method_name_from_compile_specs (
185+         compile_specs : List [CompileSpec ],
186+     ) ->  str :
187+         """ 
188+         Returns the method name from the compile specs. 
189+         """ 
190+         for  spec  in  compile_specs :
191+             if  spec .key  ==  COMPILE_SPEC_KEYS .METHOD_NAME .value :
192+                 return  spec .value .decode ("utf-8" )
193+         raise  RuntimeError (
194+             f"Could not find method name in compile specs: { compile_specs }  " 
195+         )
0 commit comments