Skip to content

Commit e570373

Browse files
authored
[ET-VK][ez] Fuse update_cache + custom_sdpa into sdpa_with_kv_cache
Differential Revision: D86340339 Pull Request resolved: #15618
1 parent 1b54822 commit e570373

File tree

4 files changed

+170
-0
lines changed

4 files changed

+170
-0
lines changed

backends/vulkan/patterns/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ runtime.python_library(
1212
"quantized_linear.py",
1313
"quantized_convolution.py",
1414
"quantized_binary.py",
15+
"sdpa.py",
1516
"select_as_symint.py",
1617
],
1718
visibility = [

backends/vulkan/patterns/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import executorch.backends.vulkan.patterns.rope # noqa
1616

17+
import executorch.backends.vulkan.patterns.sdpa # noqa
18+
1719
import executorch.backends.vulkan.patterns.select_as_symint # noqa
1820

1921
import torch

backends/vulkan/patterns/sdpa.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any, Optional
8+
9+
import executorch.backends.vulkan.utils as utils
10+
11+
import torch
12+
13+
from executorch.backends.vulkan.patterns.pattern_registry import (
14+
PatternMatch,
15+
register_pattern_detector,
16+
register_pattern_replacement,
17+
)
18+
19+
from executorch.exir import ExportedProgram
20+
21+
22+
def is_update_cache_node(node: Any) -> bool:
    """Return True if `node` is a call to the llama::update_cache op."""
    return utils.node_has_target(node, "llama::update_cache")
24+
25+
26+
def is_custom_sdpa_node(node: Any) -> bool:
    """Return True if `node` is a call to the llama::custom_sdpa op."""
    return utils.node_has_target(node, "llama::custom_sdpa")
28+
29+
30+
def is_sdpa_with_kv_cache_node(node: Any) -> bool:
    """Return True if `node` is a call to the llama::sdpa_with_kv_cache op."""
    return utils.node_has_target(node, "llama::sdpa_with_kv_cache")
32+
33+
34+
class CausalSDPAMatch(PatternMatch):
    """Match for a llama::custom_sdpa call plus the llama::update_cache
    calls that write the K/V projections into the caches it reads.

    The fused replacement op (llama::sdpa_with_kv_cache) consumes the K/V
    projections directly, so the match records both the SDPA arguments and
    the two cache-update nodes feeding it.
    """

    def __init__(self, custom_sdpa_node: torch.fx.Node) -> None:
        self.anchor_node = custom_sdpa_node
        self.match_found = False
        self.all_nodes = [self.anchor_node]

        # llama.custom_sdpa has signature:
        # custom_sdpa(query, key_cache, value_cache, start_pos, attn_mask,
        #             dropout_p, is_causal, scale) -> output
        #
        # BUGFIX: the first 7 arguments are indexed unconditionally below, so
        # the guard must require at least 7; the previous `< 4` check allowed
        # an IndexError for calls carrying only 4-6 positional args.
        if len(custom_sdpa_node.args) < 7:
            return

        self.query_node = custom_sdpa_node.args[0]
        self.key_cache_node = custom_sdpa_node.args[1]
        self.value_cache_node = custom_sdpa_node.args[2]
        self.start_pos_node = custom_sdpa_node.args[3]
        self.attn_mask_node = custom_sdpa_node.args[4]
        self.dropout_p_node = custom_sdpa_node.args[5]
        self.is_causal_node = custom_sdpa_node.args[6]
        # scale is optional; default to None when not provided.
        if len(custom_sdpa_node.args) > 7:
            self.scale_node = custom_sdpa_node.args[7]
        else:
            self.scale_node = None

        # Try to find the update_cache node that writes into the key cache;
        # its first argument is the key projection tensor.
        self.update_key_cache_node = None
        for user in self.key_cache_node.users:
            if is_update_cache_node(user):
                self.update_key_cache_node = user
                break

        self.key_projection_node = None
        if self.update_key_cache_node is not None:
            self.key_projection_node = self.update_key_cache_node.args[0]

        # Same search for the value cache.
        self.update_value_cache_node = None
        for user in self.value_cache_node.users:
            if is_update_cache_node(user):
                self.update_value_cache_node = user
                break

        self.value_projection_node = None
        if self.update_value_cache_node is not None:
            self.value_projection_node = self.update_value_cache_node.args[0]

        # We have additional optional arguments but we don't need to capture
        # them since the new op doesn't use them.
        #
        # NOTE(review): match_found is set even when the update_cache feeders
        # were not found, yet the replacement asserts they are present —
        # confirm the registry tolerates matches that trip those assertions.
        self.match_found = True
83+
84+
85+
@register_pattern_detector("causal_sdpa")
def find_causal_sdpa_patterns(
    node: torch.fx.Node,
) -> Optional[CausalSDPAMatch]:
    """Return a CausalSDPAMatch anchored at `node` when it is a
    llama::custom_sdpa call that completes the pattern, else None."""
    if is_custom_sdpa_node(node):
        candidate = CausalSDPAMatch(node)
        if candidate.match_found:
            return candidate
    return None
97+
98+
99+
##
100+
## Pattern Replacement
101+
##
102+
103+
104+
def find_singleton_start_pos_node(graph_module: torch.fx.GraphModule):
    """Locate the graph's shared start_pos node.

    It is read from the first llama::update_cache (arg index 2) or
    llama::sdpa_with_kv_cache (arg index 5) node encountered.

    Raises:
        Exception: if neither op appears anywhere in the graph.
    """
    for candidate in graph_module.graph.nodes:
        if is_update_cache_node(candidate):
            return candidate.args[2]
        if is_sdpa_with_kv_cache_node(candidate):
            return candidate.args[5]

    raise Exception(
        "Could not find an instance of llama::update_cache or sdpa_with_kv_cache"
    )
115+
116+
117+
@register_pattern_replacement("causal_sdpa")
def replace_custom_sdpa_with_causal_sdpa(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: CausalSDPAMatch,
):
    """Rewrite a matched custom_sdpa + update_cache pair into one
    llama::sdpa_with_kv_cache call and erase the now-fused cache updates."""
    assert match.update_key_cache_node is not None
    assert match.key_projection_node is not None
    assert match.update_value_cache_node is not None
    assert match.value_projection_node is not None

    start_pos = find_singleton_start_pos_node(graph_module)

    graph = graph_module.graph
    fused_args = (
        match.query_node,
        match.key_projection_node,
        match.value_projection_node,
        match.key_cache_node,
        match.value_cache_node,
        start_pos,
        1,
        match.attn_mask_node,
        match.dropout_p_node,
        match.is_causal_node,
        match.scale_node,
    )
    with graph.inserting_before(match.anchor_node):
        fused_node = graph.create_node(
            "call_function",
            torch.ops.llama.sdpa_with_kv_cache.default,
            args=fused_args,
        )

    fused_node.meta["val"] = match.anchor_node.meta["val"]
    match.anchor_node.replace_all_uses_with(fused_node)

    # Manually erase update_cache nodes since DCE will not remove them since
    # they modify inputs (specifically, the cache args are modified)
    graph.erase_node(match.update_key_cache_node)
    graph.erase_node(match.update_value_cache_node)

backends/vulkan/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,18 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
373373
return None
374374

375375

376+
def node_has_target(node: Any, target: str) -> bool:
    """Return True if `node` has an fx target matching `target`.

    Handles both string targets and OpOverload-style targets that expose a
    ``name()`` method. Objects without a ``target`` attribute (or with a
    target of any other kind) are rejected, so this is safe to call on
    arbitrary objects.

    Note: adds the ``-> bool`` return annotation for consistency with the
    ``is_*_node`` predicates that wrap this helper.
    """
    if not hasattr(node, "target"):
        return False

    if isinstance(node.target, str):
        return node.target == target
    elif hasattr(node.target, "name"):
        return node.target.name() == target

    return False
386+
387+
376388
##
377389
## Memory Layout, Storage Type Determination
378390
##

0 commit comments

Comments
 (0)