Skip to content

Commit 64cdf41

Browse files
author
ssjia
committed
Update on "[ET-VK] Implementation of to_dim_order_copy"
Title says it all! Previously, to_dim_order_copy was handled by removing the op. However, this is not possible if the op is modifying the dtype of the original tensor, so these instances of the op would be skipped by the partitioner. This diff adds an implementation of dtype conversion, which allows to_dim_order_copy to be lowered. Differential Revision: [D86340341](https://our.internmc.facebook.com/intern/diff/D86340341/) [ghstack-poisoned]
2 parents db6d764 + 1f826ea commit 64cdf41

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

backends/vulkan/patterns/sdpa.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from typing import Any, Optional
88

9+
import executorch.backends.vulkan.utils as utils
10+
911
import torch
1012

1113
from executorch.backends.vulkan.patterns.pattern_registry import (
@@ -15,31 +17,18 @@
1517
)
1618

1719
from executorch.exir import ExportedProgram
18-
from executorch.exir.dialects._ops import ops as exir_ops
1920

2021

2122
def is_update_cache_node(node: Any) -> bool:
22-
if not hasattr(node, "target"):
23-
return False
23+
return utils.node_has_target(node, "llama::update_cache")
2424

25-
if isinstance(node.target, str):
26-
return node.target == "llama::update_cache"
27-
elif hasattr(node.target, "name"):
28-
return node.target.name() == "llama::update_cache"
29-
else:
30-
return False
3125

26+
def is_custom_sdpa_node(node: Any) -> bool:
27+
return utils.node_has_target(node, "llama::custom_sdpa")
3228

33-
def is_sdpa_with_kv_cache_node(node: Any) -> bool:
34-
if not hasattr(node, "target"):
35-
return False
3629

37-
if isinstance(node.target, str):
38-
return "sdpa_with_kv_cache" in node.target
39-
elif hasattr(node.target, "name"):
40-
return "sdpa_with_kv_cache" in node.target.name()
41-
else:
42-
return False
30+
def is_sdpa_with_kv_cache_node(node: Any) -> bool:
31+
return utils.node_has_target(node, "llama::sdpa_with_kv_cache")
4332

4433

4534
class CausalSDPAMatch(PatternMatch):
@@ -97,7 +86,7 @@ def __init__(self, custom_sdpa_node: torch.fx.Node) -> None:
9786
def find_causal_sdpa_patterns(
9887
node: torch.fx.Node,
9988
) -> Optional[CausalSDPAMatch]:
100-
if node.target != exir_ops.edge.llama.custom_sdpa.default:
89+
if not is_custom_sdpa_node(node):
10190
return None
10291

10392
matched_pattern = CausalSDPAMatch(node)

backends/vulkan/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,18 @@ def find_quant_user(node: torch.fx.Node) -> Optional[torch.fx.Node]:
373373
return None
374374

375375

376+
def node_has_target(node: Any, target: str):
377+
if not hasattr(node, "target"):
378+
return False
379+
380+
if isinstance(node.target, str):
381+
return node.target == target
382+
elif hasattr(node.target, "name"):
383+
return node.target.name() == target
384+
385+
return False
386+
387+
376388
##
377389
## Memory Layout, Storage Type Determination
378390
##

0 commit comments

Comments
 (0)