Commit b1309e7
Aoti support multi method (pytorch#14715)
This pull request introduces several improvements to the CUDA backend. The main changes are a new graph pass that replaces unnecessary `slice_copy` operations, better tracking of method names in compilation artifacts, and a more robust and accurate preprocessing pipeline.

**Key changes:**

### Graph optimization and preprocessing

* Introduced `ReplaceSliceCopyWithSlicePass`, a new export pass that replaces non-mutated `slice_copy` operations with more efficient `slice` view operations in the computational graph (`replace_slice_copy_with_slice.py`, used in `cuda_backend.py`).
* Added context managers for attention-kernel selection and no-grad mode during AOT compilation to ensure the correct backend is selected for decomposition. This is a short-term measure until a flash-attention CUDA kernel is available.

### Method name and compile specification handling

* Added a `COMPILE_SPEC_KEYS` enum and utility methods (`generate_method_name_compile_spec`, `method_name_from_compile_specs`) to consistently embed the method name in the compile specs and retrieve it as a key into the data store, improving traceability of compiled artifacts. A usage sketch follows the commit message.

### Code cleanup and maintainability

* Minor refactor in `cuda_partitioner.py` to clarify delegation-tag assignment.
* Improved imports and code organization in `cuda_backend.py` for clarity.

These changes collectively improve the reliability, performance, and maintainability of the CUDA backend pipeline.
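Below is a minimal sketch of how the method-name compile spec is intended to flow through a multi-method lowering. The `CudaBackend` helpers are the ones added in this diff; the method name `"forward"` is illustrative.

```python
# Hypothetical usage sketch (method name is illustrative).
from executorch.backends.cuda.cuda_backend import CudaBackend

# One compile spec per method, so each compiled artifact can be keyed by name.
forward_specs = [CudaBackend.generate_method_name_compile_spec("forward")]

# During preprocess(), the backend recovers the name from the specs ...
method_name = CudaBackend.method_name_from_compile_specs(forward_specs)
assert method_name == "forward"
# ... and stores the AOTI shared library as "forward_so_blob" under the
# "aoti_cuda_blob" named data store, keeping multi-method artifacts distinct.
```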
1 parent: 96dfa9c

File tree: 3 files changed, +166 −5 lines

* backends/cuda/cuda_backend.py
* backends/cuda/cuda_partitioner.py
* backends/cuda/replace_slice_copy_with_slice.py

backends/cuda/cuda_backend.py (47 additions, 3 deletions)
```diff
@@ -7,10 +7,14 @@
 import contextlib
 import os
 import typing
+from enum import Enum

 from typing import Any, Dict, final, List, Optional, Set

 import torch
+from executorch.backends.cuda.replace_slice_copy_with_slice import (
+    ReplaceSliceCopyWithSlicePass,
+)
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_details import (
@@ -21,7 +25,7 @@
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch.export.passes import move_to_device_pass
-
+from torch.nn.attention import SDPBackend

 # exist fallback operators in et namespace;
 supported_fallback_kernels: Dict[str, Any] = {}
@@ -30,6 +34,10 @@
 missing_fallback_kernels: Set[str] = set()


+class COMPILE_SPEC_KEYS(Enum):
+    METHOD_NAME = "method_name"
+
+
 # context manager for non-fallback guarantee
 # it will raise exception when generating fallback kernels during aoti compile
 @contextlib.contextmanager
@@ -108,6 +116,9 @@ def preprocess(
         # Move the edge_program from CPU to CUDA for aoti compile
         cuda_edge_program = move_to_device_pass(edge_program, "cuda")

+        # replace slice_copy with slice
+        ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)
+
         edge_program_module = cuda_edge_program.module()

         # Grab all input placeholders from the graph
@@ -132,7 +143,10 @@ def preprocess(
             "max_autotune_conv_backends": "TRITON",
         }

-        with collect_unsupported_fallback_kernels():
+        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
+            [SDPBackend.MATH]
+        ), torch.no_grad():
+            # torch._logging.set_logs(post_grad_graphs=True)
             so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
         if len(missing_fallback_kernels) > 0:
             formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
@@ -146,7 +160,10 @@ def preprocess(
             so_data = f.read()

         named_data_store = NamedDataStore()
-        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")
+        method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
+        named_data_store.add_named_data(
+            method_name + "_so_blob", so_data, 1, "aoti_cuda_blob"
+        )

         # Clean up the generated so file; it has been packaged into the NamedDataStore
         # pyre-ignorep[6]: Incompatible parameter type
@@ -157,3 +174,30 @@ def preprocess(
             debug_handle_map={},
             data_store_output=named_data_store.get_named_data_store_output(),
         )
+
+    @staticmethod
+    def generate_method_name_compile_spec(
+        method_name: str,
+    ) -> CompileSpec:
+        """
+        Returns the compile spec that embeds the name of the method being compiled,
+        so compiled artifacts can be keyed by method in the named data store.
+        """
+        return CompileSpec(
+            COMPILE_SPEC_KEYS.METHOD_NAME.value,
+            method_name.encode("utf-8"),
+        )
+
+    @staticmethod
+    def method_name_from_compile_specs(
+        compile_specs: List[CompileSpec],
+    ) -> str:
+        """
+        Returns the method name embedded in the given compile specs.
+        """
+        for spec in compile_specs:
+            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
+                return spec.value.decode("utf-8")
+        raise RuntimeError(
+            f"Could not find method name in compile specs: {compile_specs}"
+        )
```
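The `sdpa_kernel([SDPBackend.MATH])` context above pins scaled-dot-product attention to the decomposable math backend during AOT compilation. A standalone sketch of the same pinning, with illustrative tensor shapes:

```python
# Minimal sketch: force the math SDPA backend so attention decomposes into
# plain ops that AOTInductor can compile without a flash-attention kernel.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(2, 8, 128, 64, device="cuda")
with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```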

backends/cuda/cuda_partitioner.py (4 additions, 2 deletions)
```diff
@@ -44,12 +44,14 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         """

         partition_tags: Dict[str, DelegationSpec] = {}
+        tag = "tag0"
+
         for node in exported_program.graph.nodes:
             if node.op != "call_function":
                 continue
-            tag = "tag0"
             node.meta["delegation_tag"] = tag
-            partition_tags[tag] = self.delegation_spec
+
+        partition_tags[tag] = self.delegation_spec

         tag_constant_data(exported_program)
```
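The behavior is unchanged by this refactor: every `call_function` node receives the same `"tag0"` delegation tag, and `partition_tags` is now populated once after the loop instead of once per node. A hedged sketch of the resulting single-partition behavior, assuming the partitioner is constructed from a list of compile specs and `ep` is an already-exported `ExportedProgram`:

```python
# Illustrative only: "ep" stands for an ExportedProgram to be partitioned.
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner

partitioner = CudaPartitioner([])  # assumption: takes a list of compile specs
result = partitioner.partition(ep)

# All call_function nodes share one tag, so the whole graph is delegated
# to the CUDA backend as a single partition.
assert list(result.partition_tags.keys()) == ["tag0"]
```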

backends/cuda/replace_slice_copy_with_slice.py (new file, 115 additions)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Iterable

import torch
from executorch.exir.dialects._ops import ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch import fx


_SLICE_COPY_TARGETS = (
    torch.ops.aten.slice_copy.Tensor,
    ops.edge.aten.slice_copy.Tensor,
)

_SLICE_TARGETS = {
    torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
    ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
}


class ReplaceSliceCopyWithSlicePass(ExportPass):
    """Replace non-mutated ``slice_copy`` results with ``slice`` views."""

    def call(self, graph_module: fx.GraphModule) -> PassResult:
        graph_changed = False

        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
                continue

            if self._has_blocking_user(node, node.users.keys()):
                continue

            node.target = _SLICE_TARGETS[node.target]
            graph_changed = True

        if graph_changed:
            graph_module.graph.lint()
            graph_module.recompile()

        return PassResult(graph_module, graph_changed)

    def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool:
        for user in users:
            if self._is_mutating_user(node, user) or self._is_view_user(node, user):
                return True
        return False

    def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool:
        if user.op == "call_method":
            # Treat in-place tensor methods conservatively as mutations only when
            # the method name ends with ``_``, the PyTorch convention for mutation.
            return isinstance(user.target, str) and user.target.endswith("_")

        if user.op != "call_function":
            return False

        target = user.target
        if not hasattr(target, "_schema"):
            return False

        schema = target._schema  # pyre-ignore[16]
        # Positional arguments
        for index, arg in enumerate(user.args):
            if arg is node and self._argument_mutates(schema, index):
                return True

        # Keyword arguments
        for name, arg in user.kwargs.items():
            if arg is node and self._argument_mutates(schema, name):
                return True

        return False

    def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
        if user.op == "call_method":
            # Treat tensor methods conservatively and assume they may be view-producing.
            return True

        if user.op != "call_function":
            return False

        target = user.target
        if getattr(target, "is_view", False):
            for arg in user.args:
                if arg is node:
                    return True
            for arg in user.kwargs.values():
                if arg is node:
                    return True

        return False

    def _argument_mutates(
        self, schema: torch._C.FunctionSchema, key
    ) -> bool:  # pyre-ignore[11]
        arguments = schema.arguments
        if isinstance(key, int):
            if key >= len(arguments):
                return False
            argument = arguments[key]
        else:
            argument = next((arg for arg in arguments if arg.name == key), None)
            if argument is None:
                return False

        alias_info = argument.alias_info
        return bool(alias_info and alias_info.is_write)
```
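A minimal sketch of running this pass standalone, assuming `gm` is the `graph_module` of an exported program (the diff applies it to `cuda_edge_program.graph_module` in exactly this way):

```python
from executorch.backends.cuda.replace_slice_copy_with_slice import (
    ReplaceSliceCopyWithSlicePass,
)

# Calling the pass rewrites safe slice_copy nodes to view-producing slice ops.
result = ReplaceSliceCopyWithSlicePass()(gm)
# result.modified is True when at least one node was rewritten; in that case
# the pass has already linted and recompiled the graph module.
```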
