Skip to content

Commit 5a40be7

Browse files
committed
Address comments
1 parent 8c3ec9e commit 5a40be7

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

backends/cuda/cuda_backend.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
from typing import Any, Dict, final, List, Optional, Set
1313

1414
import torch
15-
from executorch.backends.cuda.replace_slice_copy_with_slice import ReplaceSliceCopyWithSlicePass
15+
from executorch.backends.cuda.replace_slice_copy_with_slice import (
16+
ReplaceSliceCopyWithSlicePass,
17+
)
1618
from executorch.exir._serialize._named_data_store import NamedDataStore
1719
from executorch.exir._warnings import experimental
1820
from executorch.exir.backend.backend_details import (
@@ -31,9 +33,11 @@
3133
# required fallback kernels but not supported
3234
missing_fallback_kernels: Set[str] = set()
3335

36+
3437
class COMPILE_SPEC_KEYS(Enum):
3538
METHOD_NAME = "method_name"
3639

40+
3741
# context manager for non-fallback guarantee
3842
# it will raise exception when generating fallback kernels during aoti compile
3943
@contextlib.contextmanager
@@ -139,8 +143,10 @@ def preprocess(
139143
"max_autotune_conv_backends": "TRITON",
140144
}
141145

142-
with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
143-
torch._logging.set_logs(post_grad_graphs=True)
146+
with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
147+
[SDPBackend.MATH]
148+
), torch.no_grad():
149+
# torch._logging.set_logs(post_grad_graphs=True)
144150
so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options) # type: ignore[arg-type]
145151
if len(missing_fallback_kernels) > 0:
146152
formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
@@ -155,7 +161,9 @@ def preprocess(
155161

156162
named_data_store = NamedDataStore()
157163
method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
158-
named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, "aoti_cuda_blob")
164+
named_data_store.add_named_data(
165+
method_name + "_so_blob", so_data, 1, "aoti_cuda_blob"
166+
)
159167

160168
# Clean up the generated so file; it has been packaged into the NamdeDataStore
161169
# pyre-ignorep[6]: Incompatible parameter type

backends/cuda/replace_slice_copy_with_slice.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
9898

9999
return False
100100

101-
def _argument_mutates(self, schema: torch._C.FunctionSchema, key) -> bool: # pyre-ignore[11]
101+
def _argument_mutates(
102+
self, schema: torch._C.FunctionSchema, key
103+
) -> bool: # pyre-ignore[11]
102104
arguments = schema.arguments
103105
if isinstance(key, int):
104106
if key >= len(arguments):

0 commit comments

Comments
 (0)