1212from typing import Any , Dict , final , List , Optional , Set
1313
1414import torch
15- from executorch .backends .cuda .replace_slice_copy_with_slice import ReplaceSliceCopyWithSlicePass
15+ from executorch .backends .cuda .replace_slice_copy_with_slice import (
16+ ReplaceSliceCopyWithSlicePass ,
17+ )
1618from executorch .exir ._serialize ._named_data_store import NamedDataStore
1719from executorch .exir ._warnings import experimental
1820from executorch .exir .backend .backend_details import (
3133# required fallback kernels but not supported
3234missing_fallback_kernels : Set [str ] = set ()
3335
36+
3437class COMPILE_SPEC_KEYS (Enum ):
3538 METHOD_NAME = "method_name"
3639
40+
3741# context manager for non-fallback guarantee
3842# it will raise exception when generating fallback kernels during aoti compile
3943@contextlib .contextmanager
@@ -139,8 +143,10 @@ def preprocess(
139143 "max_autotune_conv_backends" : "TRITON" ,
140144 }
141145
142- with collect_unsupported_fallback_kernels (), torch .nn .attention .sdpa_kernel ([SDPBackend .MATH ]), torch .no_grad ():
143- torch ._logging .set_logs (post_grad_graphs = True )
146+ with collect_unsupported_fallback_kernels (), torch .nn .attention .sdpa_kernel (
147+ [SDPBackend .MATH ]
148+ ), torch .no_grad ():
149+ # torch._logging.set_logs(post_grad_graphs=True)
144150 so_path = torch ._inductor .aot_compile (edge_program_module , tuple (user_input_placeholders ), options = options ) # type: ignore[arg-type]
145151 if len (missing_fallback_kernels ) > 0 :
146152 formatted_kernels = "\n - " .join (sorted (missing_fallback_kernels ))
@@ -155,7 +161,9 @@ def preprocess(
155161
156162 named_data_store = NamedDataStore ()
157163 method_name = CudaBackend .method_name_from_compile_specs (compile_specs )
158- named_data_store .add_named_data (method_name + "_so_blob" , so_data , 1 , "aoti_cuda_blob" )
164+ named_data_store .add_named_data (
165+ method_name + "_so_blob" , so_data , 1 , "aoti_cuda_blob"
166+ )
159167
160168 # Clean up the generated so file; it has been packaged into the NamdeDataStore
161169 # pyre-ignorep[6]: Incompatible parameter type
0 commit comments