1 change: 1 addition & 0 deletions backends/vulkan/patterns/TARGETS
@@ -12,6 +12,7 @@ runtime.python_library(
"quantized_linear.py",
"quantized_convolution.py",
"quantized_binary.py",
"sdpa.py",
"select_as_symint.py",
],
visibility = [
2 changes: 2 additions & 0 deletions backends/vulkan/patterns/__init__.py
@@ -14,6 +14,8 @@

import executorch.backends.vulkan.patterns.rope # noqa

import executorch.backends.vulkan.patterns.sdpa # noqa

import executorch.backends.vulkan.patterns.select_as_symint # noqa

import torch
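Note: the bare imports above are intentional. Importing executorch.backends.vulkan.patterns.sdpa runs its @register_pattern_detector and @register_pattern_replacement decorators as a side effect, which is what registers the new pattern; the # noqa suppresses the unused-import lint. A minimal sketch of this import-for-side-effect registry idiom (illustrative only; the real pattern_registry implementation is not shown in this diff):

from typing import Callable, Dict

_DETECTORS: Dict[str, Callable] = {}

def register_pattern_detector(name: str):
    # Registration happens when the decorated module is imported, which is
    # why __init__.py imports each patterns module with "# noqa".
    def wrap(fn: Callable) -> Callable:
        _DETECTORS[name] = fn
        return fn

    return wrap

@register_pattern_detector("causal_sdpa")
def find_causal_sdpa_patterns(node):
    ...

assert "causal_sdpa" in _DETECTORS  # populated at import time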
166 changes: 166 additions & 0 deletions backends/vulkan/patterns/sdpa.py
@@ -0,0 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional

import torch

from executorch.backends.vulkan.patterns.pattern_registry import (
    PatternMatch,
    register_pattern_detector,
    register_pattern_replacement,
)

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops


def is_update_cache_node(node: Any) -> bool:
    if not hasattr(node, "target"):
        return False

    if isinstance(node.target, str):
        return node.target == "llama::update_cache"
    elif hasattr(node.target, "name"):
        return node.target.name() == "llama::update_cache"
    else:
        return False


def is_sdpa_with_kv_cache_node(node: Any) -> bool:
    if not hasattr(node, "target"):
        return False

    if isinstance(node.target, str):
        return "sdpa_with_kv_cache" in node.target
    elif hasattr(node.target, "name"):
        return "sdpa_with_kv_cache" in node.target.name()
    else:
        return False


class CausalSDPAMatch(PatternMatch):
    def __init__(self, custom_sdpa_node: torch.fx.Node) -> None:
        self.anchor_node = custom_sdpa_node
        self.match_found = False
        self.all_nodes = [self.anchor_node]

        # llama.custom_sdpa has signature:
        # custom_sdpa(query, key_cache, value_cache, start_pos, attn_mask,
        #             dropout_p, is_causal, scale) -> output
        # Indices 0 through 6 are accessed below, so require at least 7 args.
        if len(custom_sdpa_node.args) < 7:
            return

        self.query_node = custom_sdpa_node.args[0]
        self.key_cache_node = custom_sdpa_node.args[1]
        self.value_cache_node = custom_sdpa_node.args[2]
        self.start_pos_node = custom_sdpa_node.args[3]
        self.attn_mask_node = custom_sdpa_node.args[4]
        self.dropout_p_node = custom_sdpa_node.args[5]
        self.is_causal_node = custom_sdpa_node.args[6]
        if len(custom_sdpa_node.args) > 7:
            self.scale_node = custom_sdpa_node.args[7]
        else:
            self.scale_node = None

        # Try to find the update_cache node that writes the key cache
        self.update_key_cache_node = None
        for user in self.key_cache_node.users:
            if is_update_cache_node(user):
                self.update_key_cache_node = user
                break

        self.key_projection_node = None
        if self.update_key_cache_node is not None:
            self.key_projection_node = self.update_key_cache_node.args[0]

        # Find the update_cache node that writes the value cache
        self.update_value_cache_node = None
        for user in self.value_cache_node.users:
            if is_update_cache_node(user):
                self.update_value_cache_node = user
                break

        self.value_projection_node = None
        if self.update_value_cache_node is not None:
            self.value_projection_node = self.update_value_cache_node.args[0]

        # custom_sdpa accepts additional optional arguments, but they are not
        # captured here since the replacement op does not use them.

        self.match_found = True


@register_pattern_detector("causal_sdpa")
def find_causal_sdpa_patterns(
    node: torch.fx.Node,
) -> Optional[CausalSDPAMatch]:
    if node.target != exir_ops.edge.llama.custom_sdpa.default:
        return None

    matched_pattern = CausalSDPAMatch(node)
    if matched_pattern.match_found:
        return matched_pattern

    return None


##
## Pattern Replacement
##


def find_singleton_start_pos_node(
    graph_module: torch.fx.GraphModule,
) -> torch.fx.Node:
    # Every llama::update_cache and sdpa_with_kv_cache call in the graph shares
    # the same start_pos value, so the first instance found is sufficient.
    for node in graph_module.graph.nodes:
        if is_update_cache_node(node):
            return node.args[2]

        if is_sdpa_with_kv_cache_node(node):
            return node.args[5]

    raise RuntimeError(
        "Could not find an instance of llama::update_cache or sdpa_with_kv_cache"
    )


@register_pattern_replacement("causal_sdpa")
def replace_custom_sdpa_with_causal_sdpa(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    match: CausalSDPAMatch,
):
    assert match.update_key_cache_node is not None
    assert match.key_projection_node is not None
    assert match.update_value_cache_node is not None
    assert match.value_projection_node is not None

    singleton_start_pos_node = find_singleton_start_pos_node(graph_module)

    with graph_module.graph.inserting_before(match.anchor_node):
        new_node = graph_module.graph.create_node(
            "call_function",
            torch.ops.llama.sdpa_with_kv_cache.default,
            args=(
                match.query_node,
                match.key_projection_node,
                match.value_projection_node,
                match.key_cache_node,
                match.value_cache_node,
                singleton_start_pos_node,
                1,  # seq_len
                match.attn_mask_node,
                match.dropout_p_node,
                match.is_causal_node,
                match.scale_node,
            ),
        )

    new_node.meta["val"] = match.anchor_node.meta["val"]
    match.anchor_node.replace_all_uses_with(new_node)

    # Manually erase the update_cache nodes; DCE will not remove them because
    # they mutate their inputs (specifically, the cache args).
    graph_module.graph.erase_node(match.update_key_cache_node)
    graph_module.graph.erase_node(match.update_value_cache_node)
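For readers unfamiliar with the graph surgery in replace_custom_sdpa_with_causal_sdpa: the insert / propagate-meta / replace-uses / erase sequence is the standard torch.fx rewrite idiom. A self-contained toy version of the same flow (illustrative only, not part of this PR; it swaps add for mul on a traced module):

import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

gm = torch.fx.symbolic_trace(Toy())

# Same surgery as the pass above: insert a replacement op before the anchor,
# propagate metadata, redirect all uses, then erase the original node.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and node.target is torch.add:
        with gm.graph.inserting_before(node):
            new_node = gm.graph.create_node(
                "call_function", torch.mul, args=node.args
            )
        new_node.meta = node.meta.copy()
        node.replace_all_uses_with(new_node)
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
print(gm(torch.ones(2), torch.full((2,), 3.0)))  # tensor([3., 3.])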