
Commit 97a8762

Author: ssjia
[ET-VK][ez] Fuse update_cache + custom_sdpa into sdpa_with_kv_cache
SDPA used to be handled by a single custom op, `sdpa_with_kv_cache`, but it was eventually split (D62301837) into separate `update_cache` and `custom_sdpa` ops. For Vulkan, however, a single fused op is preferable, since it allows more control over how the cache tensors are stored and represented; it makes the cache tensors easier to manage and opens up opportunities for future optimizations.

This diff introduces a fusion pass that does two things:

1. Combines `update_cache` and `custom_sdpa` back into `sdpa_with_kv_cache`.
2. Ensures all references to the `cache_pos` symint use the same node, which prevents the `select_at_dim_as_symint` op from being re-run for every use.

Differential Revision: [D86340339](https://our.internmc.facebook.com/intern/diff/D86340339/)

ghstack-source-id: 321176636
Pull Request resolved: #15618
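For orientation, a before/after schematic of the rewrite (argument lists abbreviated; the names follow the replacement code in `sdpa.py` below):

```python
# Before: two in-place cache writes feed a separate attention op.
#   llama.update_cache(k_proj, key_cache, cache_pos)
#   llama.update_cache(v_proj, value_cache, cache_pos)
#   out = llama.custom_sdpa(q, key_cache, value_cache, cache_pos,
#                           attn_mask, dropout_p, is_causal, scale)
#
# After: a single fused op receives the projections and the caches directly.
#   out = llama.sdpa_with_kv_cache(q, k_proj, v_proj, key_cache, value_cache,
#                                  cache_pos, 1, attn_mask, dropout_p,
#                                  is_causal, scale)
```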
1 parent 2bf55d8 commit 97a8762

File tree

3 files changed: +169, -0 lines changed

backends/vulkan/patterns/TARGETS
backends/vulkan/patterns/__init__.py
backends/vulkan/patterns/sdpa.py

backends/vulkan/patterns/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ runtime.python_library(
         "quantized_linear.py",
         "quantized_convolution.py",
         "quantized_binary.py",
+        "sdpa.py",
         "select_as_symint.py",
     ],
     visibility = [

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
 
 import executorch.backends.vulkan.patterns.rope  # noqa
 
+import executorch.backends.vulkan.patterns.sdpa  # noqa
+
 import executorch.backends.vulkan.patterns.select_as_symint  # noqa
 
 import torch
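These bare imports do real work: importing the module executes the `register_pattern_detector` / `register_pattern_replacement` decorators in `sdpa.py`, which record the functions in the backend's pattern registry. A minimal sketch of that decorator-registration idiom (an illustration only, not ExecuTorch's actual `pattern_registry` implementation):

```python
from typing import Callable, Dict

# Hypothetical module-level registry; the real one lives in
# executorch.backends.vulkan.patterns.pattern_registry.
PATTERN_DETECTORS: Dict[str, Callable] = {}


def register_pattern_detector(name: str) -> Callable:
    # Record the decorated detector under `name` at import time.
    def wrap(fn: Callable) -> Callable:
        PATTERN_DETECTORS[name] = fn
        return fn

    return wrap


@register_pattern_detector("causal_sdpa")
def find_causal_sdpa_patterns(node):
    return None  # placeholder body


# Merely importing this module has populated the registry:
assert "causal_sdpa" in PATTERN_DETECTORS
```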

backends/vulkan/patterns/sdpa.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Optional
+
+import torch
+
+from executorch.backends.vulkan.patterns.pattern_registry import (
+    PatternMatch,
+    register_pattern_detector,
+    register_pattern_replacement,
+)
+
+from executorch.exir import ExportedProgram
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+def is_update_cache_node(node: Any) -> bool:
+    if not hasattr(node, "target"):
+        return False
+
+    if isinstance(node.target, str):
+        return node.target == "llama::update_cache"
+    elif hasattr(node.target, "name"):
+        return node.target.name() == "llama::update_cache"
+    else:
+        return False
+
+
+def is_sdpa_with_kv_cache_node(node: Any) -> bool:
+    if not hasattr(node, "target"):
+        return False
+
+    if isinstance(node.target, str):
+        return "sdpa_with_kv_cache" in node.target
+    elif hasattr(node.target, "name"):
+        return "sdpa_with_kv_cache" in node.target.name()
+    else:
+        return False
+
+
+class CausalSDPAMatch(PatternMatch):
+    def __init__(self, custom_sdpa_node: torch.fx.Node) -> None:
+        self.anchor_node = custom_sdpa_node
+        self.match_found = False
+        self.all_nodes = [self.anchor_node]
+
+        # llama.custom_sdpa has signature:
+        # custom_sdpa(query, key_cache, value_cache, start_pos, attn_mask, dropout_p, is_causal, scale) -> output
+        if len(custom_sdpa_node.args) < 7:  # args[4:7] are indexed directly below
+            return
+
+        self.query_node = custom_sdpa_node.args[0]
+        self.key_cache_node = custom_sdpa_node.args[1]
+        self.value_cache_node = custom_sdpa_node.args[2]
+        self.start_pos_node = custom_sdpa_node.args[3]
+        self.attn_mask_node = custom_sdpa_node.args[4]
+        self.dropout_p_node = custom_sdpa_node.args[5]
+        self.is_causal_node = custom_sdpa_node.args[6]
+        if len(custom_sdpa_node.args) > 7:
+            self.scale_node = custom_sdpa_node.args[7]
+        else:
+            self.scale_node = None
+
+        # Try to find the node that updates the key cache
+        self.update_key_cache_node = None
+        for user in self.key_cache_node.users:
+            if is_update_cache_node(user):
+                self.update_key_cache_node = user
+                break
+
+        self.key_projection_node = None
+        if self.update_key_cache_node is not None:
+            self.key_projection_node = self.update_key_cache_node.args[0]
+
+        # Try to find the node that updates the value cache
+        self.update_value_cache_node = None
+        for user in self.value_cache_node.users:
+            if is_update_cache_node(user):
+                self.update_value_cache_node = user
+                break
+
+        self.value_projection_node = None
+        if self.update_value_cache_node is not None:
+            self.value_projection_node = self.update_value_cache_node.args[0]
+
+        # There are additional optional arguments, but we don't need to capture
+        # them since the new op doesn't use them.
+
+        self.match_found = True
+
+
+@register_pattern_detector("causal_sdpa")
+def find_causal_sdpa_patterns(
+    node: torch.fx.Node,
+) -> Optional[CausalSDPAMatch]:
+    if node.target != exir_ops.edge.llama.custom_sdpa.default:
+        return None
+
+    matched_pattern = CausalSDPAMatch(node)
+    if matched_pattern.match_found:
+        return matched_pattern
+
+    return None
+
+
+##
+## Pattern Replacement
+##
+
+
+def find_singleton_start_pos_node(graph_module: torch.fx.GraphModule):
+    for node in graph_module.graph.nodes:
+        if is_update_cache_node(node):
+            return node.args[2]
+
+        if is_sdpa_with_kv_cache_node(node):
+            return node.args[5]
+
+    raise RuntimeError(
+        "Could not find an instance of llama::update_cache or sdpa_with_kv_cache"
+    )
+
+
+@register_pattern_replacement("causal_sdpa")
+def replace_custom_sdpa_with_causal_sdpa(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: CausalSDPAMatch,
+):
+    assert match.update_key_cache_node is not None
+    assert match.key_projection_node is not None
+    assert match.update_value_cache_node is not None
+    assert match.value_projection_node is not None
+
+    singleton_start_pos_node = find_singleton_start_pos_node(graph_module)
+
+    with graph_module.graph.inserting_before(match.anchor_node):
+        new_node = graph_module.graph.create_node(
+            "call_function",
+            torch.ops.llama.sdpa_with_kv_cache.default,
+            args=(
+                match.query_node,
+                match.key_projection_node,
+                match.value_projection_node,
+                match.key_cache_node,
+                match.value_cache_node,
+                singleton_start_pos_node,
+                1,
+                match.attn_mask_node,
+                match.dropout_p_node,
+                match.is_causal_node,
+                match.scale_node,
+            ),
+        )
+
+    new_node.meta["val"] = match.anchor_node.meta["val"]
+    match.anchor_node.replace_all_uses_with(new_node)
+
+    # Manually erase the update_cache nodes; DCE will not remove them because
+    # they mutate their inputs (specifically, the cache tensors).
+    graph_module.graph.erase_node(match.update_key_cache_node)
+    graph_module.graph.erase_node(match.update_value_cache_node)
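The replacement above is standard torch.fx graph surgery: insert the fused node before the anchor, redirect all uses to it, then erase the originals by hand (DCE alone would leave the side-effecting `update_cache` nodes in place). A self-contained sketch of the same mechanics on stock ops, fusing a mul + add pair into a single `addcmul`; this is an illustration of the idiom, not the ExecuTorch pass itself:

```python
import operator

import torch
import torch.fx


class MulAdd(torch.nn.Module):
    def forward(self, x, y, z):
        return x + y * z


gm = torch.fx.symbolic_trace(MulAdd())

for node in list(gm.graph.nodes):
    # Anchor on the add, then check that its right operand is a mul node.
    if node.op != "call_function" or node.target is not operator.add:
        continue
    lhs, rhs = node.args
    if not (isinstance(rhs, torch.fx.Node) and rhs.target is operator.mul):
        continue

    # Insert the fused op before the anchor, redirect all uses to it, then
    # erase the originals (consumer first, producer second).
    with gm.graph.inserting_before(node):
        fused = gm.graph.call_function(torch.addcmul, args=(lhs, *rhs.args))
    node.replace_all_uses_with(fused)
    gm.graph.erase_node(node)
    gm.graph.erase_node(rhs)

gm.graph.lint()
gm.recompile()

x, y, z = torch.randn(4), torch.randn(4), torch.randn(4)
assert torch.allclose(gm(x, y, z), x + y * z)
```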
