Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions backends/vulkan/_passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,22 @@ runtime.python_library(
],
)

runtime.python_library(
name = "fuse_patterns",
srcs = ["fuse_patterns.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/backends/vulkan/patterns:vulkan_patterns",
"//executorch/exir:lib",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
typing = True,
)

runtime.python_library(
name = "vulkan_passes",
srcs = [
Expand All @@ -128,6 +144,7 @@ runtime.python_library(
"//executorch/examples/...",
],
deps = [
":fuse_patterns",
":fuse_quantized_ops",
":insert_prepack_nodes",
":int4_weight_only_quantizer",
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
from executorch.backends.vulkan._passes.fuse_quantized_ops import (
FuseQuantizedOpsTransform,
)
Expand All @@ -29,6 +30,7 @@
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass

__all__ = [
"FusePatternsPass",
"FuseQuantizedOpsTransform",
"insert_prepack_nodes",
"VkInt4WeightOnlyQuantizer",
Expand Down
126 changes: 126 additions & 0 deletions backends/vulkan/_passes/fuse_patterns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
from typing import Callable, List, Optional

import executorch.backends.vulkan.patterns as vk_patterns

import torch

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher


def fuse_pattern(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
patterns: List[torch.fx.GraphModule],
create_replacement_func: Callable,
) -> int:
total_replaced = 0

for pattern in patterns:
sm = SubgraphMatcher(pattern.graph, ignore_literals=True)
matches = list(sm.match(graph_module.graph))

for partition_to_replace in matches:
create_replacement_func(ep, graph_module, partition_to_replace)
total_replaced += 1
# Remove dead code so they won't be matched again
graph_module.graph.eliminate_dead_code()

return total_replaced


##
## Rotary Embedding
##


def identify_rotary_emb_io_nodes(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
match: InternalMatch,
) -> Optional[List[torch.fx.Node]]:
# Get the input placeholders (xq, xk, freqs_cos, freqs_sin)
placeholder_nodes = match.placeholder_nodes
if len(placeholder_nodes) != 4:
return None

xq, xk, freqs_cos, freqs_sin = placeholder_nodes

output_nodes = match.returning_nodes
if len(output_nodes) != 2:
return None

xq_out, xk_out = output_nodes

return [xq, xk, freqs_cos, freqs_sin, xq_out, xk_out]


def create_rotary_emb_custom_op(
ep: ExportedProgram,
graph_module: torch.fx.GraphModule,
match: InternalMatch,
):
io_nodes = identify_rotary_emb_io_nodes(ep, graph_module, match)
if io_nodes is None:
return

assert len(io_nodes) == 6
xq, xk, freqs_cos, freqs_sin, xq_out, xk_out = io_nodes

# Create the custom op node
with graph_module.graph.inserting_before(xq_out):
rotary_emb_node = graph_module.graph.create_node(
"call_function",
exir_ops.edge.et_vk.apply_rotary_emb.default,
args=(xq, xk, freqs_cos, freqs_sin),
)

# The custom op returns a tuple (xq_out, xk_out)
# We need to extract the individual outputs
with graph_module.graph.inserting_after(rotary_emb_node):
getitem_0 = graph_module.graph.create_node(
"call_function",
operator.getitem,
args=(rotary_emb_node, 0),
)
getitem_1 = graph_module.graph.create_node(
"call_function",
operator.getitem,
args=(rotary_emb_node, 1),
)

if hasattr(xq_out, "meta") and "val" in xq_out.meta:
getitem_0.meta["val"] = xq_out.meta["val"]
if hasattr(xk_out, "meta") and "val" in xk_out.meta:
getitem_1.meta["val"] = xk_out.meta["val"]

xq_out.replace_all_uses_with(getitem_0)
xk_out.replace_all_uses_with(getitem_1)


class FusePatternsPass(ExportPass):
def __init__(self, exported_program: ExportedProgram) -> None:
super().__init__()
self.program = exported_program

def call(self, graph_module: torch.fx.GraphModule):
total_replaced = vk_patterns.replace_all_fusable_subgraphs(
self.program, graph_module
)

if total_replaced > 0:
graph_module.recompile()
# Re-trace the graph
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, total_replaced > 0)
36 changes: 3 additions & 33 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.vulkan.patterns as vk_patterns
import torch.library

namespace = "et_vk"
Expand Down Expand Up @@ -325,42 +326,11 @@ def linear_qta8a_qga4w(
######################


# Note that this implementation is copied from executorch.examples.models.llama.rope
# but it is copied here to avoid introducing a dependency on the llama code.
def apply_rotary_emb_impl(
xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
):
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
freqs_cis_ndim = freqs_cis.ndim
if freqs_cis_ndim == 3:
# freqs_cis: (seq_len, n_heads, head_dim // 2)
assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
shape = [
d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
for i, d in enumerate(x.shape)
]
else:
# freqs_cis: (seq_len, head_dim // 2)
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(shape)

xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

return xq_out.type_as(xq), xk_out.type_as(xk)
pattern = vk_patterns.RotaryEmbeddingPattern()
return pattern.forward(xq, xk, freqs_cos, freqs_sin)


name = "apply_rotary_emb"
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def update_features_impl(op: OpKey):
operator.gt,
operator.ge,
operator.le,
operator.eq,
# Guard and assert ops
torch.ops.aten._assert_scalar.default,
torch.ops.aten.sym_constrain_range_for_size.default,
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/partitioner/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ runtime.python_library(
"//executorch/backends/vulkan:op_registry",
"//executorch/backends/vulkan:utils_lib",
"//executorch/backends/vulkan:vulkan_preprocess",
"//executorch/backends/vulkan/patterns:vulkan_patterns",
"//executorch/exir:delegate",
"//executorch/exir:lib",
"//executorch/exir/backend:partitioner",
Expand Down
22 changes: 21 additions & 1 deletion backends/vulkan/partitioner/vulkan_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple

import executorch.backends.vulkan.patterns as vk_patterns
import executorch.backends.vulkan.utils as utils

import torch
Expand Down Expand Up @@ -37,9 +38,10 @@
from executorch.exir.dialects._ops import ops as exir_ops

from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.utils.matcher_utils import InternalMatch

# pyre-ignore
ops_not_to_decompose = [
Expand All @@ -58,6 +60,7 @@ def __init__(
require_dynamic_shape: bool = False,
operator_blocklist: Optional[Set[OpKey]] = None,
operator_allowlist: Optional[Set[OpKey]] = None,
fusable_subgraphs: Optional[List[InternalMatch]] = None,
) -> None:
super().__init__()
self.texture_limits: utils.ImageExtents = texture_limits
Expand All @@ -67,6 +70,13 @@ def __init__(
operator_blocklist if operator_blocklist is not None else set()
)
self.operator_allowlist = operator_allowlist
self.fusable_subgraphs: List[InternalMatch] = (
fusable_subgraphs if fusable_subgraphs is not None else []
)
# Create a set of all nodes that are part of fusable subgraphs for quick lookup
self.fusable_nodes: Set[torch.fx.Node] = set()
for match in self.fusable_subgraphs:
self.fusable_nodes.update(match.nodes_map.values())

def op_node_is_compatible( # noqa: C901: Function is too complex
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
Expand Down Expand Up @@ -204,6 +214,10 @@ def is_node_supported(
return r

def _is_node_supported(self, node: torch.fx.Node) -> bool:
# Check if this node is part of a fusable subgraph
if node.op == "call_function" and node in self.fusable_nodes:
return True

target = node.target
if node.target == torch.ops.higher_order.auto_functionalized:
first_arg = node.args[0]
Expand Down Expand Up @@ -330,6 +344,11 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
# subgraphs containing the nodes with the tags
partition_tags = {}

# Get all fusable subgraphs from fuse_patterns
fusable_subgraphs = vk_patterns.get_all_fusable_subgraphs(
exported_program.graph_module
)

texture_limits: utils.ImageExtents = self.options.get(
"texture_limits", utils.DEFAULT_TEXTURE_LIMITS
)
Expand All @@ -342,6 +361,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
operator_blocklist=self.operator_blocklist,
operator_allowlist=self.operator_allowlist,
fusable_subgraphs=fusable_subgraphs,
),
allows_single_node_partition=True,
)
Expand Down
24 changes: 24 additions & 0 deletions backends/vulkan/patterns/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "vulkan_patterns",
srcs = [
"__init__.py",
"pattern_registry.py",
"rope.py",
],
visibility = [
"//executorch/backends/...",
"//executorch/examples/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/backends/transforms:utils",
"//executorch/backends/vulkan:utils_lib",
],
typing = True,
)
Loading
Loading