Skip to content

Commit 9bdd6dc

Browse files
author
ssjia
committed
Update on "[ET-VK][ez] Fuse update_cache + custom_sdpa into sdpa_with_kv_cache"
SDPA used to be handled by a custom op `sdpa_with_kv_cache`, but it was eventually split (D62301837) into update_cache and custom_sdpa ops. However, having a single fused op is useful for Vulkan since it allows more control over how the cache tensors are stored and represented. Essentially, it makes it easier to manage the cache tensors and opens up opportunities for future optimizations. This diff introduces a fusion pass that does 2 things: 1. Combine update_cache and custom_sdpa back into sdpa_with_kv_cache 2. Ensure all references to the cache_pos symint use the same node - this prevents the select_at_dim_as_symint op from being called every time it is used. Differential Revision: [D86340339](https://our.internmc.facebook.com/intern/diff/D86340339/) [ghstack-poisoned]
2 parents e5495a8 + 1b54822 commit 9bdd6dc

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

backends/vulkan/patterns/sdpa.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from typing import Any, Optional
88

9+
import executorch.backends.vulkan.utils as utils
10+
911
import torch
1012

1113
from executorch.backends.vulkan.patterns.pattern_registry import (
@@ -15,31 +17,18 @@
1517
)
1618

1719
from executorch.exir import ExportedProgram
18-
from executorch.exir.dialects._ops import ops as exir_ops
1920

2021

2122
def is_update_cache_node(node: Any) -> bool:
22-
if not hasattr(node, "target"):
23-
return False
23+
return utils.node_has_target(node, "llama::update_cache")
2424

25-
if isinstance(node.target, str):
26-
return node.target == "llama::update_cache"
27-
elif hasattr(node.target, "name"):
28-
return node.target.name() == "llama::update_cache"
29-
else:
30-
return False
3125

26+
def is_custom_sdpa_node(node: Any) -> bool:
27+
return utils.node_has_target(node, "llama::custom_sdpa")
3228

33-
def is_sdpa_with_kv_cache_node(node: Any) -> bool:
34-
if not hasattr(node, "target"):
35-
return False
3629

37-
if isinstance(node.target, str):
38-
return "sdpa_with_kv_cache" in node.target
39-
elif hasattr(node.target, "name"):
40-
return "sdpa_with_kv_cache" in node.target.name()
41-
else:
42-
return False
30+
def is_sdpa_with_kv_cache_node(node: Any) -> bool:
31+
return utils.node_has_target(node, "llama::sdpa_with_kv_cache")
4332

4433

4534
class CausalSDPAMatch(PatternMatch):
@@ -97,7 +86,7 @@ def __init__(self, custom_sdpa_node: torch.fx.Node) -> None:
9786
def find_causal_sdpa_patterns(
9887
node: torch.fx.Node,
9988
) -> Optional[CausalSDPAMatch]:
100-
if node.target != exir_ops.edge.llama.custom_sdpa.default:
89+
if not is_custom_sdpa_node(node):
10190
return None
10291

10392
matched_pattern = CausalSDPAMatch(node)

backends/vulkan/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,18 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
373373
return None
374374

375375

376+
def node_has_target(node: Any, target: str):
377+
if not hasattr(node, "target"):
378+
return False
379+
380+
if isinstance(node.target, str):
381+
return node.target == target
382+
elif hasattr(node.target, "name"):
383+
return node.target.name() == target
384+
385+
return False
386+
387+
376388
##
377389
## Memory Layout, Storage Type Determination
378390
##

0 commit comments

Comments
 (0)