Skip to content

Commit 4438d31

Browse files
pytorchbot and ssjia authored
[ET-VK] Move rotary embedding custom op to be handled via graph pass instead of source transform (#13465)
## Motivation

Be able to test Vulkan lowering via optimum-executorch.

## Context

Currently, ET-VK implements rotary embeddings via a custom op. This op is currently inserted into Transformer models by replacing Rotary Embedding modules with a custom module that executes the custom op via a source transform.

The source transform approach makes it cumbersome to lower LLMs to Vulkan, since it requires the export logic to apply the source transform before calling `torch.export()`. This in turn makes it difficult to integrate Vulkan lowering into optimum-executorch, which tries to use a common export + lowering logic for all lowering paths.

As an alternative, leverage `SubgraphMatcher` to detect fusable patterns and fuse the rotary embedding graph pattern into the custom op as part of the Vulkan delegate's graph passes. This removes the requirement to apply a custom source transform just for Vulkan.

## Changes

* Introduce the `backends/vulkan/patterns` folder to store fusable graph patterns
* Introduce a fusable graph pattern for rotary positional embeddings
* Update partitioner logic to automatically include nodes that are part of a fusable graph pattern
* Introduce a pass to fuse known patterns into custom ops / custom op sequence

Differential Revision: [D80293301](https://our.internmc.facebook.com/intern/diff/D80293301/)

Co-authored-by: ssjia <[email protected]>
1 parent 7fbca4d commit 4438d31

File tree

18 files changed

+534
-76
lines changed

18 files changed

+534
-76
lines changed

backends/vulkan/_passes/TARGETS

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,22 @@ runtime.python_library(
118118
],
119119
)
120120

121+
runtime.python_library(
122+
name = "fuse_patterns",
123+
srcs = ["fuse_patterns.py"],
124+
visibility = [
125+
"//executorch/backends/...",
126+
],
127+
deps = [
128+
"//caffe2:torch",
129+
"//executorch/backends/vulkan/patterns:vulkan_patterns",
130+
"//executorch/exir:lib",
131+
"//executorch/exir:pass_base",
132+
"//executorch/exir/dialects:lib",
133+
],
134+
typing = True,
135+
)
136+
121137
runtime.python_library(
122138
name = "vulkan_passes",
123139
srcs = [
@@ -128,6 +144,7 @@ runtime.python_library(
128144
"//executorch/examples/...",
129145
],
130146
deps = [
147+
":fuse_patterns",
131148
":fuse_quantized_ops",
132149
":insert_prepack_nodes",
133150
":int4_weight_only_quantizer",

backends/vulkan/_passes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
# pyre-strict
88

9+
from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
910
from executorch.backends.vulkan._passes.fuse_quantized_ops import (
1011
FuseQuantizedOpsTransform,
1112
)
@@ -29,6 +30,7 @@
2930
from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
3031

3132
__all__ = [
33+
"FusePatternsPass",
3234
"FuseQuantizedOpsTransform",
3335
"insert_prepack_nodes",
3436
"VkInt4WeightOnlyQuantizer",
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import executorch.backends.vulkan.patterns as vk_patterns
8+
9+
import torch
10+
11+
from executorch.exir import ExportedProgram
12+
from executorch.exir.pass_base import ExportPass, PassResult
13+
14+
15+
class FusePatternsPass(ExportPass):
    """Graph pass that replaces every registered fusable pattern it finds.

    The pass keeps a reference to the exported program it was constructed
    with and forwards it, together with the graph module being processed, to
    ``vk_patterns.replace_all_fusable_subgraphs``.
    """

    def __init__(self, exported_program: ExportedProgram) -> None:
        super().__init__()
        # Forwarded to the pattern replacement helper on every call.
        self.program = exported_program

    def call(self, graph_module: torch.fx.GraphModule):
        """Fuse known patterns in ``graph_module``.

        Returns a ``PassResult`` holding the resulting graph module and a
        flag that is True when at least one subgraph was replaced.
        """
        num_fused = vk_patterns.replace_all_fusable_subgraphs(
            self.program, graph_module
        )
        modified = num_fused > 0

        if modified:
            graph_module.recompile()
            # Re-trace the rewritten graph through the base ExportPass.
            retraced = super().call(graph_module)
            graph_module = retraced.graph_module

        return PassResult(graph_module, modified)

backends/vulkan/custom_ops_lib.py

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import executorch.backends.vulkan.patterns as vk_patterns
78
import torch.library
89

910
namespace = "et_vk"
@@ -325,42 +326,11 @@ def linear_qta8a_qga4w(
325326
######################
326327

327328

328-
# Note that this implementation is copied from executorch.examples.models.llama.rope
329-
# but it is copied here to avoid introducing a dependency on the llama code.
330329
def apply_rotary_emb_impl(
331330
xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
332331
):
333-
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
334-
ndim = x.ndim
335-
freqs_cis_ndim = freqs_cis.ndim
336-
if freqs_cis_ndim == 3:
337-
# freqs_cis: (seq_len, n_heads, head_dim // 2)
338-
assert freqs_cis.shape == (x.shape[-3], x.shape[-2], x.shape[-1])
339-
shape = [
340-
d if (i == ndim - 3 or i == ndim - 2 or i == ndim - 1) else 1
341-
for i, d in enumerate(x.shape)
342-
]
343-
else:
344-
# freqs_cis: (seq_len, head_dim // 2)
345-
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
346-
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
347-
return freqs_cis.view(shape)
348-
349-
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
350-
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
351-
352-
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
353-
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
354-
355-
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
356-
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
357-
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
358-
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
359-
360-
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
361-
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
362-
363-
return xq_out.type_as(xq), xk_out.type_as(xk)
332+
pattern = vk_patterns.RotaryEmbeddingPattern()
333+
return pattern.forward(xq, xk, freqs_cos, freqs_sin)
364334

365335

366336
name = "apply_rotary_emb"

backends/vulkan/op_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def update_features_impl(op: OpKey):
125125
operator.gt,
126126
operator.ge,
127127
operator.le,
128+
operator.eq,
128129
# Guard and assert ops
129130
torch.ops.aten._assert_scalar.default,
130131
torch.ops.aten.sym_constrain_range_for_size.default,

backends/vulkan/partitioner/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ runtime.python_library(
1515
"//executorch/backends/vulkan:op_registry",
1616
"//executorch/backends/vulkan:utils_lib",
1717
"//executorch/backends/vulkan:vulkan_preprocess",
18+
"//executorch/backends/vulkan/patterns:vulkan_patterns",
1819
"//executorch/exir:delegate",
1920
"//executorch/exir:lib",
2021
"//executorch/exir/backend:partitioner",

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
from typing import Any, Callable, Dict, final, List, Mapping, Optional, Set, Tuple
1111

12+
import executorch.backends.vulkan.patterns as vk_patterns
1213
import executorch.backends.vulkan.utils as utils
1314

1415
import torch
@@ -37,9 +38,10 @@
3738
from executorch.exir.dialects._ops import ops as exir_ops
3839

3940
from torch.export.exported_program import ExportedProgram
40-
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
4141

42+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
4243
from torch.fx.passes.operator_support import OperatorSupportBase
44+
from torch.fx.passes.utils.matcher_utils import InternalMatch
4345

4446
# pyre-ignore
4547
ops_not_to_decompose = [
@@ -58,6 +60,7 @@ def __init__(
5860
require_dynamic_shape: bool = False,
5961
operator_blocklist: Optional[Set[OpKey]] = None,
6062
operator_allowlist: Optional[Set[OpKey]] = None,
63+
fusable_subgraphs: Optional[List[InternalMatch]] = None,
6164
) -> None:
6265
super().__init__()
6366
self.texture_limits: utils.ImageExtents = texture_limits
@@ -67,6 +70,13 @@ def __init__(
6770
operator_blocklist if operator_blocklist is not None else set()
6871
)
6972
self.operator_allowlist = operator_allowlist
73+
self.fusable_subgraphs: List[InternalMatch] = (
74+
fusable_subgraphs if fusable_subgraphs is not None else []
75+
)
76+
# Create a set of all nodes that are part of fusable subgraphs for quick lookup
77+
self.fusable_nodes: Set[torch.fx.Node] = set()
78+
for match in self.fusable_subgraphs:
79+
self.fusable_nodes.update(match.nodes_map.values())
7080

7181
def op_node_is_compatible( # noqa: C901: Function is too complex
7282
self, node: torch.fx.Node, features: Optional[OpFeatures] = None
@@ -204,6 +214,10 @@ def is_node_supported(
204214
return r
205215

206216
def _is_node_supported(self, node: torch.fx.Node) -> bool:
217+
# Check if this node is part of a fusable subgraph
218+
if node.op == "call_function" and node in self.fusable_nodes:
219+
return True
220+
207221
target = node.target
208222
if node.target == torch.ops.higher_order.auto_functionalized:
209223
first_arg = node.args[0]
@@ -330,6 +344,11 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
330344
# subgraphs containing the nodes with the tags
331345
partition_tags = {}
332346

347+
# Get all fusable subgraphs from fuse_patterns
348+
fusable_subgraphs = vk_patterns.get_all_fusable_subgraphs(
349+
exported_program.graph_module
350+
)
351+
333352
texture_limits: utils.ImageExtents = self.options.get(
334353
"texture_limits", utils.DEFAULT_TEXTURE_LIMITS
335354
)
@@ -342,6 +361,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
342361
require_dynamic_shape=self.options.get("require_dynamic_shapes", False),
343362
operator_blocklist=self.operator_blocklist,
344363
operator_allowlist=self.operator_allowlist,
364+
fusable_subgraphs=fusable_subgraphs,
345365
),
346366
allows_single_node_partition=True,
347367
)

backends/vulkan/patterns/TARGETS

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")
2+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
3+
4+
oncall("executorch")
5+
6+
runtime.python_library(
7+
name = "vulkan_patterns",
8+
srcs = [
9+
"__init__.py",
10+
"pattern_registry.py",
11+
"rope.py",
12+
],
13+
visibility = [
14+
"//executorch/backends/...",
15+
"//executorch/examples/...",
16+
],
17+
deps = [
18+
"//caffe2:torch",
19+
"//executorch/exir:lib",
20+
"//executorch/backends/transforms:utils",
21+
"//executorch/backends/vulkan:utils_lib",
22+
],
23+
typing = True,
24+
)

backends/vulkan/patterns/__init__.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import List
8+
9+
import executorch.backends.vulkan.patterns.rope # noqa
10+
11+
import torch
12+
13+
from executorch.backends.vulkan.patterns.pattern_registry import (
14+
CreateReplacementFn,
15+
fusable_patterns,
16+
GetGraphFn,
17+
register_pattern_graph,
18+
register_pattern_replacement,
19+
)
20+
21+
from executorch.backends.vulkan.patterns.rope import RotaryEmbeddingPattern
22+
23+
from executorch.exir import ExportedProgram
24+
25+
from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher
26+
27+
28+
__all__ = [
29+
"GetGraphFn",
30+
"CreateReplacementFn",
31+
"RotaryEmbeddingPattern",
32+
"fusable_patterns",
33+
"register_pattern_graph",
34+
"register_pattern_replacement",
35+
]
36+
37+
38+
def all_fusable_graph_patterns() -> List[torch.fx.GraphModule]:
    """Collect the pattern graphs of every entry in ``fusable_patterns``.

    Entries whose ``get_graphs_fn`` is ``None`` are skipped; for the others
    the function is called and everything it yields is appended, in registry
    iteration order.
    """
    return [
        pattern_graph
        for entry in fusable_patterns.values()
        if entry.get_graphs_fn is not None
        for pattern_graph in entry.get_graphs_fn()
    ]
45+
46+
47+
def get_all_fusable_subgraphs(
    graph_module: torch.fx.GraphModule,
) -> List[InternalMatch]:
    """Return every match of every registered pattern in ``graph_module``.

    Each pattern graph from ``all_fusable_graph_patterns()`` is matched
    against ``graph_module.graph`` with ``ignore_literals=True``; the matches
    are concatenated in pattern order.
    """
    found: List[InternalMatch] = []

    for pattern_module in all_fusable_graph_patterns():
        matcher = SubgraphMatcher(pattern_module.graph, ignore_literals=True)
        found += matcher.match(graph_module.graph)

    return found
59+
60+
61+
def create_replacement_for_pattern(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
    patterns: List[torch.fx.GraphModule],
    create_replacement_func: CreateReplacementFn,
) -> int:
    """Apply ``create_replacement_func`` to every match of ``patterns``.

    For each pattern graph, all matches in ``graph_module.graph`` are
    computed up front (with ``ignore_literals=True``) and the replacement
    callback is invoked once per match with ``(ep, graph_module, match)``.

    Returns:
        The number of matches the callback was invoked for.
    """
    num_replaced = 0

    for pattern_module in patterns:
        matcher = SubgraphMatcher(pattern_module.graph, ignore_literals=True)

        for match in matcher.match(graph_module.graph):
            create_replacement_func(ep, graph_module, match)
            num_replaced += 1
            # Drop the now-dead nodes so that they are not matched again.
            graph_module.graph.eliminate_dead_code()

    return num_replaced
80+
81+
82+
def replace_all_fusable_subgraphs(
    ep: ExportedProgram,
    graph_module: torch.fx.GraphModule,
) -> int:
    """Run the replacement step for every fully registered pattern.

    An entry of ``fusable_patterns`` is processed only when both its
    ``get_graphs_fn`` and its ``create_replacement_fn`` are set; the pattern
    graphs and the replacement callback are then handed to
    ``create_replacement_for_pattern``.

    Returns:
        The sum of the replacement counts reported for all processed entries.
    """
    replaced_count = 0

    for entry in fusable_patterns.values():
        # Skip entries that lack either the pattern graphs or the replacement.
        if entry.get_graphs_fn is None or entry.create_replacement_fn is None:
            continue

        pattern_graphs = entry.get_graphs_fn()
        replaced_count += create_replacement_for_pattern(
            ep,
            graph_module,
            pattern_graphs,
            # pyre-ignore[6]
            entry.create_replacement_fn,
        )

    return replaced_count

0 commit comments

Comments
 (0)