Skip to content

Commit 31b4610

Browse files
author
ssjia
committed
Update base for Update on "[ET-VK] buffer implementation of rotary positional embeddings"
Title says it all! Differential Revision: [D86340338](https://our.internmc.facebook.com/intern/diff/D86340338/) [ghstack-poisoned]
1 parent 4685fe8 commit 31b4610

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

backends/vulkan/patterns/sdpa.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from typing import Any, Optional
88

9+
import executorch.backends.vulkan.utils as utils
10+
911
import torch
1012

1113
from executorch.backends.vulkan.patterns.pattern_registry import (
@@ -15,31 +17,18 @@
1517
)
1618

1719
from executorch.exir import ExportedProgram
18-
from executorch.exir.dialects._ops import ops as exir_ops
1920

2021

2122
def is_update_cache_node(node: Any) -> bool:
22-
if not hasattr(node, "target"):
23-
return False
23+
return utils.node_has_target(node, "llama::update_cache")
2424

25-
if isinstance(node.target, str):
26-
return node.target == "llama::update_cache"
27-
elif hasattr(node.target, "name"):
28-
return node.target.name() == "llama::update_cache"
29-
else:
30-
return False
3125

26+
def is_custom_sdpa_node(node: Any) -> bool:
27+
return utils.node_has_target(node, "llama::custom_sdpa")
3228

33-
def is_sdpa_with_kv_cache_node(node: Any) -> bool:
34-
if not hasattr(node, "target"):
35-
return False
3629

37-
if isinstance(node.target, str):
38-
return "sdpa_with_kv_cache" in node.target
39-
elif hasattr(node.target, "name"):
40-
return "sdpa_with_kv_cache" in node.target.name()
41-
else:
42-
return False
30+
def is_sdpa_with_kv_cache_node(node: Any) -> bool:
31+
return utils.node_has_target(node, "llama::sdpa_with_kv_cache")
4332

4433

4534
class CausalSDPAMatch(PatternMatch):
@@ -97,7 +86,7 @@ def __init__(self, custom_sdpa_node: torch.fx.Node) -> None:
9786
def find_causal_sdpa_patterns(
9887
node: torch.fx.Node,
9988
) -> Optional[CausalSDPAMatch]:
100-
if node.target != exir_ops.edge.llama.custom_sdpa.default:
89+
if not is_custom_sdpa_node(node):
10190
return None
10291

10392
matched_pattern = CausalSDPAMatch(node)

backends/vulkan/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,18 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
373373
return None
374374

375375

376+
def node_has_target(node: Any, target: str):
377+
if not hasattr(node, "target"):
378+
return False
379+
380+
if isinstance(node.target, str):
381+
return node.target == target
382+
elif hasattr(node.target, "name"):
383+
return node.target.name() == target
384+
385+
return False
386+
387+
376388
##
377389
## Memory Layout, Storage Type Determination
378390
##

0 commit comments

Comments
 (0)