|
6 | 6 |
|
7 | 7 | from typing import Any, Optional |
8 | 8 |
|
| 9 | +import executorch.backends.vulkan.utils as utils |
| 10 | + |
9 | 11 | import torch |
10 | 12 |
|
11 | 13 | from executorch.backends.vulkan.patterns.pattern_registry import ( |
|
15 | 17 | ) |
16 | 18 |
|
17 | 19 | from executorch.exir import ExportedProgram |
18 | | -from executorch.exir.dialects._ops import ops as exir_ops |
19 | 20 |
|
20 | 21 |
|
21 | 22 | def is_update_cache_node(node: Any) -> bool: |
22 | | - if not hasattr(node, "target"): |
23 | | - return False |
| 23 | + return utils.node_has_target(node, "llama::update_cache") |
24 | 24 |
|
25 | | - if isinstance(node.target, str): |
26 | | - return node.target == "llama::update_cache" |
27 | | - elif hasattr(node.target, "name"): |
28 | | - return node.target.name() == "llama::update_cache" |
29 | | - else: |
30 | | - return False |
31 | 25 |
|
| 26 | +def is_custom_sdpa_node(node: Any) -> bool: |
| 27 | + return utils.node_has_target(node, "llama::custom_sdpa") |
32 | 28 |
|
33 | | -def is_sdpa_with_kv_cache_node(node: Any) -> bool: |
34 | | - if not hasattr(node, "target"): |
35 | | - return False |
36 | 29 |
|
37 | | - if isinstance(node.target, str): |
38 | | - return "sdpa_with_kv_cache" in node.target |
39 | | - elif hasattr(node.target, "name"): |
40 | | - return "sdpa_with_kv_cache" in node.target.name() |
41 | | - else: |
42 | | - return False |
| 30 | +def is_sdpa_with_kv_cache_node(node: Any) -> bool: |
| 31 | + return utils.node_has_target(node, "llama::sdpa_with_kv_cache") |
43 | 32 |
|
44 | 33 |
|
45 | 34 | class CausalSDPAMatch(PatternMatch): |
@@ -97,7 +86,7 @@ def __init__(self, custom_sdpa_node: torch.fx.Node) -> None: |
97 | 86 | def find_causal_sdpa_patterns( |
98 | 87 | node: torch.fx.Node, |
99 | 88 | ) -> Optional[CausalSDPAMatch]: |
100 | | - if node.target != exir_ops.edge.llama.custom_sdpa.default: |
| 89 | + if not is_custom_sdpa_node(node): |
101 | 90 | return None |
102 | 91 |
|
103 | 92 | matched_pattern = CausalSDPAMatch(node) |
|
0 commit comments