
Commit ca7cdc8

Split SDPA + KV cache operator into SDPA operator and KV cache update operator; add RemoveAsserts pass and apply it during Llama export
Differential Revision: D68922404
Pull Request resolved: #8075
1 parent e8ee36c, commit ca7cdc8
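At a high level, this change decomposes the fused llama::sdpa_with_kv_cache operator into a cache-update step and an attention step. The pseudocode below is an illustrative sketch of that decomposition; the argument order follows the Vulkan implementations in SDPA.cpp further down, and the exact custom-op schemas are an assumption rather than something taken from this commit:

    # Before: one fused op updates both caches and computes attention in a single call.
    out = llama.sdpa_with_kv_cache(
        q, k, v, k_cache, v_cache, input_pos, seq_len,
        attn_mask, dropout_p, is_causal, scale,
    )

    # After: a dedicated op writes the new K/V projections into the caches,
    # then attention reads directly from the caches.
    llama.update_cache(k, k_cache, input_pos)
    llama.update_cache(v, v_cache, input_pos)
    out = llama.custom_sdpa(
        q, k_cache, v_cache, input_pos,
        attn_mask, dropout_p, is_causal, scale,
    )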

File tree

9 files changed, +184 -37 lines

backends/vulkan/_passes/TARGETS

Lines changed: 14 additions & 0 deletions
@@ -30,6 +30,19 @@ runtime.python_library(
     ]
 )
 
+runtime.python_library(
+    name = "remove_asserts",
+    srcs = ["remove_asserts.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "remove_local_scalar_dense",
     srcs = ["remove_local_scalar_dense_ops.py"],
@@ -83,6 +96,7 @@ runtime.python_library(
     deps = [
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
+        ":remove_asserts",
         ":remove_local_scalar_dense",
         ":remove_redundant_ops",
         ":tag_memory_meta_pass"

backends/vulkan/_passes/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -2,6 +2,10 @@
 from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
     VkInt4WeightOnlyQuantizer,
 )
+from executorch.backends.vulkan._passes.remove_asserts import (
+    remove_asserts,
+    RemoveAssertsTransform,
+)
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
@@ -13,6 +17,8 @@
 __all__ = [
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
+    "remove_asserts",
+    "RemoveAssertsTransform",
     "RemoveLocalScalarDenseOpsTransform",
     "RemoveRedundantOpsTransform",
     "TagMemoryMetaPass",

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 9 additions & 1 deletion
@@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
         )
         # This pass assumes that the SpecPropPass() has already been applied
         assert "spec" in node.meta
+        # Mutable buffers will not be marked as constant, but it might as well be
+        # for the purposes of memory planning. Mark it as a constant tensor so that
+        # it is handled correctly by the memory planning pass.
+        if not node.meta["spec"].const:
+            assert is_param_node(program, node)
+            node.meta["spec"].const = True
         # Validate that the original node is marked as a constant. Constant tensors
         # do not participate in memory planning.
         assert node.meta["spec"].const
@@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
         # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
         # memory object.
         prepack_node.meta["spec"].mem_obj_id = -1
-        node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
+        node.replace_all_uses_with(
+            prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
+        )
 
     program.graph.eliminate_dead_code()
     return program
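The filtered replace_all_uses_with call above keeps the graph output pointing at the original node while rerouting every other user to the prepack node. Below is a minimal, self-contained torch.fx sketch of that pattern; the toy module and node names are illustrative and are not backend code:

    import operator

    import torch
    import torch.fx as fx


    class Toy(torch.nn.Module):
        def forward(self, x):
            a = x + 1
            return a * 2, a


    gm = fx.symbolic_trace(Toy())
    graph = gm.graph

    # Locate the `a = x + 1` node and insert a stand-in node right after it.
    add_node = next(n for n in graph.nodes if n.target is operator.add)
    with graph.inserting_after(add_node):
        stand_in = graph.call_function(operator.mul, (add_node, 1))

    # Reroute users of add_node to stand_in, except the stand-in itself and
    # the graph output node, mirroring the lambda in insert_prepack_nodes.
    add_node.replace_all_uses_with(
        stand_in, lambda user: user is not stand_in and user.op != "output"
    )

    graph.lint()
    gm.recompile()
    print(gm.code)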
backends/vulkan/_passes/remove_asserts.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Set, Union
+
+import torch
+
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.program._program import _get_updated_graph_signature
+
+from torch.export.exported_program import ExportedProgram
+
+OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]
+
+
+class RemoveAssertsTransform(ExportPass):
+    """
+    Remove operators which perform assertions. These are not possible to execute in
+    Vulkan since GLSL shaders cannot abort execution at runtime. Therefore, remove these
+    operators.
+    """
+
+    assert_ops: Set[OpType] = {
+        torch.ops.aten._assert_scalar.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
+    }
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.target in self.assert_ops:
+                graph_module.graph.erase_node(node)
+
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+def remove_asserts(edge_program: ExportedProgram) -> ExportedProgram:
+    graph_module = edge_program.graph_module
+    RemoveAssertsTransform()(graph_module)
+
+    edge_program._graph_signature = _get_updated_graph_signature(
+        edge_program.graph_signature, graph_module
+    )
+    edge_program._validate()
+    return edge_program
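A usage sketch (not part of the commit) showing how the pass can be exercised on an exported program. The toy module, the dynamic-shape setup, and the expectation that torch._check lowers to assertion nodes are assumptions that may vary across PyTorch versions; in the real flow the pass runs on the Llama edge program during export (see export_llama_lib.py below):

    import torch
    from torch.export import Dim, export

    from executorch.backends.vulkan._passes.remove_asserts import (
        RemoveAssertsTransform,
        remove_asserts,
    )


    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            n = x.shape[0]
            torch._check(n % 2 == 0)  # typically lowers to a runtime assert node
            return x * 2


    program = export(
        Toy(), (torch.randn(8, 4),), dynamic_shapes={"x": {0: Dim("n", max=128)}}
    )
    program = remove_asserts(program)

    # After the pass, no assertion operators should remain in the graph.
    remaining = [
        node
        for node in program.graph_module.graph.nodes
        if node.target in RemoveAssertsTransform.assert_ops
    ]
    print(f"assertion nodes remaining: {len(remaining)}")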

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 1 addition & 6 deletions
@@ -23,9 +23,6 @@
 
 from executorch.exir.pass_base import ExportPass, PassResult
 
-from torch.fx.passes.tools_common import NodeList
-from torch.fx.passes.utils.fuser_utils import topo_sort
-
 logger: logging.Logger = logging.getLogger("")
 logger.setLevel(logging.INFO)
 
@@ -220,9 +217,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:
 
     # noqa
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
-
-        for node in sorted_nodes:
+        for node in graph_module.graph.nodes:
            if not self.should_annotate(node) or self.should_delay_annotation(node):
                continue
 
backends/vulkan/op_registry.py

Lines changed: 11 additions & 1 deletion
@@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures):
 
 
 @update_features("llama::sdpa_with_kv_cache")
-def register_sdpa_op(features: OpFeatures):
+def register_sdpa_with_kv_cache_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         valid_packed_dims={PackedDim.WIDTH},
     )
@@ -489,6 +489,16 @@ def register_sdpa_op(features: OpFeatures):
     return features
 
 
+@update_features(["llama::update_cache", "llama::custom_sdpa"])
+def register_sdpa_ops(features: OpFeatures):
+    features.resize_fn = False
+    features.buffer_impl = False
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims={PackedDim.WIDTH},
+    )
+    return features
+
+
 @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
 def register_rotary_emb_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 11 additions & 3 deletions
@@ -250,11 +250,19 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:
            self.log_skip(node, "local scalar dense of incompatible op node")
            return False
 
+        features = None
        if target not in vulkan_supported_ops:
-            self.log_skip(node, "no operator implementation")
-            return False
+            # For some ops, i.e. custom ops the name is registered instead of the
+            # OpOverload object.
+            if not isinstance(target, str) and target.name() in vulkan_supported_ops:
+                features = vulkan_supported_ops[target.name()]
+            else:
+                self.log_skip(node, "no operator implementation")
+                return False
+        else:
+            features = vulkan_supported_ops[target]
 
-        features = vulkan_supported_ops[target]
+        assert features is not None
 
        if not features.check_node_fn(node):
            self.log_skip(node, "op args not supported")
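A simplified sketch of the lookup logic introduced above (this is not the partitioner itself, and the table contents are placeholders): custom ops such as llama::custom_sdpa are registered under their string name, so when the node target is an OpOverload object that is not itself a key in the table, the lookup falls back to target.name():

    from typing import Any, Dict, Optional, Union

    import torch

    OpKey = Union[str, torch._ops.OpOverload]


    def lookup_features(table: Dict[OpKey, Any], target: OpKey) -> Optional[Any]:
        if target in table:
            return table[target]
        # Custom ops may be registered by name rather than by OpOverload object.
        if not isinstance(target, str) and target.name() in table:
            return table[target.name()]
        # Caller logs a skip and declines to partition the node.
        return None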

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 74 additions & 26 deletions
@@ -176,17 +176,32 @@ void resize_sdpa_out(
   graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
 }
 
-void sdpa_with_kv_cache_impl(
-    ComputeGraph& graph,
-    const std::vector<ValueRef>& args) {
+void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef value = args[arg_idx++];
+  const ValueRef cache = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef out = args[arg_idx++];
+
+  // Unused variables
+  (void)out;
+
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));
+
+  add_kv_cache_update_node(graph, input_pos_symint, value, cache);
+}
+
+void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   int arg_idx = 0;
   const ValueRef q_projected = args[arg_idx++];
-  const ValueRef k_projected = args[arg_idx++];
-  const ValueRef v_projected = args[arg_idx++];
-  const ValueRef k_cache_data = args[arg_idx++];
-  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef k_cache = args[arg_idx++];
+  const ValueRef v_cache = args[arg_idx++];
   const ValueRef input_pos_symint = args[arg_idx++];
-  const ValueRef sequence_len = args[arg_idx++];
   const ValueRef attn_mask = args[arg_idx++];
   const ValueRef dropout_p = args[arg_idx++];
   const ValueRef is_causal = args[arg_idx++];
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
   // Output tensors
   const ValueRef out = args[arg_idx++];
 
-  // Unused variables
-  (void)sequence_len;
-
   // Batches must be 1
   VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
   // k and v projected must have the same shape
-  VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
+  VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
   // head dim must match between tensors
   VK_CHECK_COND(
       graph.size_at<int32_t>(-1, q_projected) ==
-      graph.size_at<int32_t>(-1, k_projected));
+      graph.size_at<int32_t>(-1, k_cache));
   // All tensors must have the packed dim be the width (head) dimension
   VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
   // Some variables are not supported yet
   VK_CHECK_COND(
       graph.val_is_none(dropout_p) ||
@@ -222,16 +234,8 @@
       graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
-  const ValueRef k_cache =
-      prepack_standard_like(graph, k_cache_data, q_projected);
-  const ValueRef v_cache =
-      prepack_standard_like(graph, v_cache_data, q_projected);
-
   const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);
 
-  add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
-  add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);
-
   // Slice caches from 0 to input_pos + sequence_len
   const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
   const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
@@ -257,7 +261,7 @@
 
   // Repeat interleave
   const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
-  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
+  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);
 
   const ValueRef num_repeats =
       graph.add_scalar<int64_t>(num_heads / num_kv_heads);
@@ -331,8 +335,52 @@
       new ExecuteNode(resize_sdpa_out, {q_projected, out}));
 }
 
+void sdpa_with_kv_cache_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q_projected = args[arg_idx++];
+  const ValueRef k_projected = args[arg_idx++];
+  const ValueRef v_projected = args[arg_idx++];
+  const ValueRef k_cache_data = args[arg_idx++];
+  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef sequence_len = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef dropout_p = args[arg_idx++];
+  const ValueRef is_causal = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+
+  // Output tensors
+  const ValueRef out = args[arg_idx++];
+
+  (void)sequence_len;
+
+  const ValueRef k_cache =
+      prepack_standard_like(graph, k_cache_data, q_projected);
+  const ValueRef v_cache =
+      prepack_standard_like(graph, v_cache_data, q_projected);
+
+  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
+  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
+
+  sdpa_impl(
+      graph,
+      {q_projected,
+       k_cache,
+       v_cache,
+       input_pos_symint,
+       attn_mask,
+       dropout_p,
+       is_causal,
+       scale,
+       out});
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
+  VK_REGISTER_OP(update_cache.default, update_cache_impl);
+  VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
 }
 
 } // namespace vkcompute
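For intuition, here is a pure-PyTorch sketch of what the cache update performed by add_kv_cache_update_node amounts to. The [batch, seq_len, n_kv_heads, head_dim] layout is inferred from the size checks above; this is a reference model of the semantics, not the Vulkan shader:

    import torch


    def update_cache_reference(
        value: torch.Tensor, cache: torch.Tensor, input_pos: int
    ) -> None:
        # value: [1, seq_len, n_kv_heads, head_dim]
        # cache: [1, max_seq_len, n_kv_heads, head_dim]
        seq_len = value.shape[1]
        cache[:, input_pos : input_pos + seq_len] = value


    cache = torch.zeros(1, 16, 2, 8)
    step = torch.randn(1, 1, 2, 8)
    update_cache_reference(step, cache, input_pos=3)
    assert torch.equal(cache[:, 3:4], step)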

examples/models/llama/export_llama_lib.py

Lines changed: 6 additions & 0 deletions
@@ -21,6 +21,8 @@
 
 import pkg_resources
 import torch
+
+from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import get_delegation_info
 
 from executorch.devtools.etrecord import generate_etrecord
@@ -727,6 +729,10 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
        )
        modelname = f"vulkan_{modelname}"
 
+        # Need to remove asserts from the graph to prevent graph breaks
+        # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
+        remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
+
    if args.mps:
        partitioners.append(get_mps_partitioner(args.use_kv_cache))
        modelname = f"mps_{modelname}"
