
Commit 8c3ec9e

Make it work
1 parent 96dfa9c commit 8c3ec9e

File tree

backends/cuda/cuda_backend.py
backends/cuda/cuda_partitioner.py
backends/cuda/replace_slice_copy_with_slice.py

3 files changed: 156 additions, 5 deletions

backends/cuda/cuda_backend.py

Lines changed: 39 additions & 3 deletions
@@ -7,10 +7,12 @@
 import contextlib
 import os
 import typing
+from enum import Enum

 from typing import Any, Dict, final, List, Optional, Set

 import torch
+from executorch.backends.cuda.replace_slice_copy_with_slice import ReplaceSliceCopyWithSlicePass
 from executorch.exir._serialize._named_data_store import NamedDataStore
 from executorch.exir._warnings import experimental
 from executorch.exir.backend.backend_details import (
@@ -21,14 +23,16 @@
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch.export.passes import move_to_device_pass
-
+from torch.nn.attention import SDPBackend

 # exist fallback operators in et namespace;
 supported_fallback_kernels: Dict[str, Any] = {}

 # required fallback kernels but not supported
 missing_fallback_kernels: Set[str] = set()

+class COMPILE_SPEC_KEYS(Enum):
+    METHOD_NAME = "method_name"

 # context manager for non-fallback guarantee
 # it will raise exception when generating fallback kernels during aoti compile
@@ -108,6 +112,9 @@ def preprocess(
         # Move the edge_program from CPU to CUDA for aoti compile
         cuda_edge_program = move_to_device_pass(edge_program, "cuda")

+        # replace slice_copy with slice
+        ReplaceSliceCopyWithSlicePass()(cuda_edge_program.graph_module)
+
         edge_program_module = cuda_edge_program.module()

         # Grab all input placeholders from the graph
@@ -132,7 +139,8 @@ def preprocess(
             "max_autotune_conv_backends": "TRITON",
         }

-        with collect_unsupported_fallback_kernels():
+        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
+            torch._logging.set_logs(post_grad_graphs=True)
             so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
         if len(missing_fallback_kernels) > 0:
             formatted_kernels = "\n - ".join(sorted(missing_fallback_kernels))
@@ -146,7 +154,8 @@ def preprocess(
             so_data = f.read()

         named_data_store = NamedDataStore()
-        named_data_store.add_named_data("so_blob", so_data, 1, "aoti_cuda_blob")
+        method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
+        named_data_store.add_named_data(method_name + "_so_blob", so_data, 1, "aoti_cuda_blob")

         # Clean up the generated so file; it has been packaged into the NamedDataStore
         # pyre-ignorep[6]: Incompatible parameter type
@@ -157,3 +166,30 @@ def preprocess(
             debug_handle_map={},
             data_store_output=named_data_store.get_named_data_store_output(),
         )
+
+    @staticmethod
+    def generate_method_name_compile_spec(
+        method_name: str,
+    ) -> CompileSpec:
+        """
+        Returns the compile spec that records the method name the delegate
+        payload is generated for.
+        """
+        return CompileSpec(
+            COMPILE_SPEC_KEYS.METHOD_NAME.value,
+            method_name.encode("utf-8"),
+        )
+
+    @staticmethod
+    def method_name_from_compile_specs(
+        compile_specs: List[CompileSpec],
+    ) -> str:
+        """
+        Returns the method name from the compile specs.
+        """
+        for spec in compile_specs:
+            if spec.key == COMPILE_SPEC_KEYS.METHOD_NAME.value:
+                return spec.value.decode("utf-8")
+        raise RuntimeError(
+            f"Could not find method name in compile specs: {compile_specs}"
+        )
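
The new compile spec is what gives the delegate blob a per-method key in the NamedDataStore (e.g. "forward_so_blob" instead of the previous shared "so_blob"). A minimal sketch of the round trip, assuming the CudaBackend import path used elsewhere in this commit:

# Sketch only: build the compile spec when lowering, recover it in preprocess().
from executorch.backends.cuda.cuda_backend import CudaBackend

compile_specs = [CudaBackend.generate_method_name_compile_spec("forward")]

# preprocess() reads the same spec back to key the AOTI shared-library blob,
# producing the named data entry "forward_so_blob" under "aoti_cuda_blob".
method_name = CudaBackend.method_name_from_compile_specs(compile_specs)
assert method_name == "forward"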

backends/cuda/cuda_partitioner.py

Lines changed: 4 additions & 2 deletions
@@ -44,12 +44,14 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         """

         partition_tags: Dict[str, DelegationSpec] = {}
+        tag = "tag0"
+
         for node in exported_program.graph.nodes:
             if node.op != "call_function":
                 continue
-            tag = "tag0"
             node.meta["delegation_tag"] = tag
-            partition_tags[tag] = self.delegation_spec
+
+        partition_tags[tag] = self.delegation_spec

         tag_constant_data(exported_program)
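
With tag hoisted out of the loop, every call_function node still lands in the single "tag0" partition, but the delegation spec is now registered once after tagging. A rough lowering sketch tying the partitioner to the method-name compile spec; it assumes CudaPartitioner accepts the compile-spec list in its constructor and that the usual export -> to_edge -> to_backend flow applies:

# Illustrative sketch only; constructor signature and flow are assumptions.
import torch
from executorch.backends.cuda.cuda_backend import CudaBackend
from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
from executorch.exir import to_edge

class MyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)

ep = torch.export.export(MyModule(), (torch.randn(2, 4),))
edge = to_edge(ep)
partitioner = CudaPartitioner([CudaBackend.generate_method_name_compile_spec("forward")])
lowered = edge.to_backend(partitioner)  # all call_function nodes are tagged "tag0"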

backends/cuda/replace_slice_copy_with_slice.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Iterable
+
+import torch
+from executorch.exir.dialects._ops import ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch import fx
+
+
+_SLICE_COPY_TARGETS = (
+    torch.ops.aten.slice_copy.Tensor,
+    ops.edge.aten.slice_copy.Tensor,
+)
+
+_SLICE_TARGETS = {
+    torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
+    ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
+}
+
+
+class ReplaceSliceCopyWithSlicePass(ExportPass):
+    """Replace non-mutated ``slice_copy`` results with ``slice`` views."""
+
+    def call(self, graph_module: fx.GraphModule) -> PassResult:
+        graph_changed = False
+
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function" or node.target not in _SLICE_COPY_TARGETS:
+                continue
+
+            if self._has_blocking_user(node, node.users.keys()):
+                continue
+
+            node.target = _SLICE_TARGETS[node.target]
+            graph_changed = True
+
+        if graph_changed:
+            graph_module.graph.lint()
+            graph_module.recompile()
+
+        return PassResult(graph_module, graph_changed)
+
+    def _has_blocking_user(self, node: fx.Node, users: Iterable[fx.Node]) -> bool:
+        for user in users:
+            if self._is_mutating_user(node, user) or self._is_view_user(node, user):
+                return True
+        return False
+
+    def _is_mutating_user(self, node: fx.Node, user: fx.Node) -> bool:
+        if user.op == "call_method":
+            # Treat in-place tensor methods conservatively as mutations only when the
+            # method name ends with ``_`` which is the PyTorch convention for mutation.
+            return isinstance(user.target, str) and user.target.endswith("_")
+
+        if user.op != "call_function":
+            return False
+
+        target = user.target
+        if not hasattr(target, "_schema"):
+            return False
+
+        schema = target._schema  # pyre-ignore[16]
+        # Positional arguments
+        for index, arg in enumerate(user.args):
+            if arg is node and self._argument_mutates(schema, index):
+                return True
+
+        # Keyword arguments
+        for name, arg in user.kwargs.items():
+            if arg is node and self._argument_mutates(schema, name):
+                return True
+
+        return False
+
+    def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
+        if user.op == "call_method":
+            # Treat tensor methods conservatively and assume they may be view-producing.
+            return True
+
+        if user.op != "call_function":
+            return False
+
+        target = user.target
+        if getattr(target, "is_view", False):
+            for arg in user.args:
+                if arg is node:
+                    return True
+            for arg in user.kwargs.values():
+                if arg is node:
+                    return True
+
+        return False
+
+    def _argument_mutates(self, schema: torch._C.FunctionSchema, key) -> bool:  # pyre-ignore[11]
+        arguments = schema.arguments
+        if isinstance(key, int):
+            if key >= len(arguments):
+                return False
+            argument = arguments[key]
+        else:
+            argument = next((arg for arg in arguments if arg.name == key), None)
+            if argument is None:
+                return False
+
+        alias_info = argument.alias_info
+        return bool(alias_info and alias_info.is_write)
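
preprocess() already runs this pass on the CUDA edge program before aot_compile (see cuda_backend.py above). As a standalone illustration, here is a hedged sketch of applying it to a decomposed export, assuming decomposition to core ATen is what introduces the slice_copy nodes:

# Sketch only: exercise the pass outside the backend.
import torch
from executorch.backends.cuda.replace_slice_copy_with_slice import ReplaceSliceCopyWithSlicePass

class SliceModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, 1:3] * 2.0

ep = torch.export.export(SliceModel(), (torch.randn(4, 8),))
ep = ep.run_decompositions()  # core ATen IR: slicing shows up as aten.slice_copy.Tensor
result = ReplaceSliceCopyWithSlicePass()(ep.graph_module)
print(result.modified)  # True when a slice_copy node was rewritten to aten.slice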
