From c2ac9291916ce7c954884a90f8870a82eb7013b8 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 31 Jan 2025 09:29:18 -0800 Subject: [PATCH] Split SDPA + KV cache operator into SDPA operator and KV cache update operator + Add `RemoveAsserts` pass and apply it during LlaMa export (#8075) Summary: **Note**: This diff is a combination of D68919676 and D68919678. I decided to combine the two because of problems with `ghexport`, which was having some problems exporting the second diff, as well as the fact that both diffs are needed for `export_llama` to work so it makes more sense to just have a single diff. ## Context Recent changes split the `sdpa_with_kv_cache` operator into two separate operators, `update_cache` and `custom_sdpa` to decouple the cache update step from the actual SDPA computation. As a result, SDPA is no longer being delegated on Vulkan because of this interface change. To rectify this, Vulkan must also split `sdpa_with_kv_cache` into two operators. Note that during this diff the new operators are not partitioned yet because of complications caused by assertion ops in the graph. The next diff adds a pass to remove such assertion ops which allows the new operators to be partitioned. ## Context Recently, some assertion ops were added to the Llama source code. Unfortunately, this causes issues for the Vulkan delegate because runtime assertions are not yet supported in Vulkan and the assertion ops cause graph breaks due to not being supported. To prevent graph breaks when delegating to Vulkan, apply a pass to remove assertion ops during the llama export. Reviewed By: kimishpatel, digantdesai Differential Revision: D68922404 --- backends/vulkan/_passes/TARGETS | 14 +++ backends/vulkan/_passes/__init__.py | 6 ++ .../vulkan/_passes/insert_prepack_nodes.py | 10 +- backends/vulkan/_passes/remove_asserts.py | 52 +++++++++ .../vulkan/_passes/tag_memory_meta_pass.py | 7 +- backends/vulkan/op_registry.py | 12 ++- .../vulkan/partitioner/vulkan_partitioner.py | 14 ++- .../vulkan/runtime/graph/ops/impl/SDPA.cpp | 100 +++++++++++++----- examples/models/llama/export_llama_lib.py | 6 ++ 9 files changed, 184 insertions(+), 37 deletions(-) create mode 100644 backends/vulkan/_passes/remove_asserts.py diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 74048cfb6a7..4e60fc7bd7e 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -30,6 +30,19 @@ runtime.python_library( ] ) +runtime.python_library( + name = "remove_asserts", + srcs = ["remove_asserts.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + runtime.python_library( name = "remove_local_scalar_dense", srcs = ["remove_local_scalar_dense_ops.py"], @@ -83,6 +96,7 @@ runtime.python_library( deps = [ ":insert_prepack_nodes", ":int4_weight_only_quantizer", + ":remove_asserts", ":remove_local_scalar_dense", ":remove_redundant_ops", ":tag_memory_meta_pass" diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index 416339574ba..8c29f5488f3 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -2,6 +2,10 @@ from executorch.backends.vulkan._passes.int4_weight_only_quantizer import ( VkInt4WeightOnlyQuantizer, ) +from executorch.backends.vulkan._passes.remove_asserts import ( + remove_asserts, + RemoveAssertsTransform, +) from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import ( RemoveLocalScalarDenseOpsTransform, ) @@ -13,6 +17,8 @@ __all__ = [ "insert_prepack_nodes", "VkInt4WeightOnlyQuantizer", + "remove_asserts", + "RemoveAssertsTransform", "RemoveLocalScalarDenseOpsTransform", "RemoveRedundantOpsTransform", "TagMemoryMetaPass", diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index 7876806d6d1..bf1fc28ba56 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool: ) # This pass assumes that the SpecPropPass() has already been applied assert "spec" in node.meta + # Mutable buffers will not be marked as constant, but it might as well be + # for the purposes of memory planning. Mark it as a constant tensor so that + # it is handled correctly by the memory planning pass. + if not node.meta["spec"].const: + assert is_param_node(program, node) + node.meta["spec"].const = True # Validate that the original node is marked as a constant. Constant tensors # do not participate in memory planning. assert node.meta["spec"].const @@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool: # Set the mem_obj_id to -1 to indicate that this node requires a dedicated # memory object. prepack_node.meta["spec"].mem_obj_id = -1 - node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y) + node.replace_all_uses_with( + prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output") + ) program.graph.eliminate_dead_code() return program diff --git a/backends/vulkan/_passes/remove_asserts.py b/backends/vulkan/_passes/remove_asserts.py new file mode 100644 index 00000000000..835f2ec1415 --- /dev/null +++ b/backends/vulkan/_passes/remove_asserts.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Set, Union + +import torch + +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.program._program import _get_updated_graph_signature + +from torch.export.exported_program import ExportedProgram + +OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload] + + +class RemoveAssertsTransform(ExportPass): + """ + Remove operators which perform assertions. These are not possible to execute in + Vulkan since GLSL shaders cannot abort execution at runtime. Therefore, remove these + operators. + """ + + assert_ops: Set[OpType] = { + torch.ops.aten._assert_scalar.default, + torch.ops.aten.sym_constrain_range_for_size.default, + } + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + for node in graph_module.graph.nodes: + if node.target in self.assert_ops: + graph_module.graph.erase_node(node) + + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) + + +def remove_asserts(edge_program: ExportedProgram) -> ExportedProgram: + graph_module = edge_program.graph_module + RemoveAssertsTransform()(graph_module) + + edge_program._graph_signature = _get_updated_graph_signature( + edge_program.graph_signature, graph_module + ) + edge_program._validate() + return edge_program diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 1d08817e26a..f2f54404ca8 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -23,9 +23,6 @@ from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx.passes.tools_common import NodeList -from torch.fx.passes.utils.fuser_utils import topo_sort - logger: logging.Logger = logging.getLogger("") logger.setLevel(logging.INFO) @@ -220,9 +217,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool: # noqa def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes)) - - for node in sorted_nodes: + for node in graph_module.graph.nodes: if not self.should_annotate(node) or self.should_delay_annotation(node): continue diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index d70cf93b883..25cf74dc8f2 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures): @update_features("llama::sdpa_with_kv_cache") -def register_sdpa_op(features: OpFeatures): +def register_sdpa_with_kv_cache_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( valid_packed_dims={PackedDim.WIDTH}, ) @@ -489,6 +489,16 @@ def register_sdpa_op(features: OpFeatures): return features +@update_features(["llama::update_cache", "llama::custom_sdpa"]) +def register_sdpa_ops(features: OpFeatures): + features.resize_fn = False + features.buffer_impl = False + features.texture_impl = TextureImplFeatures( + valid_packed_dims={PackedDim.WIDTH}, + ) + return features + + @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) def register_rotary_emb_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 3c31e0316a6..6ff3fa8d70f 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -250,11 +250,19 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool: self.log_skip(node, "local scalar dense of incompatible op node") return False + features = None if target not in vulkan_supported_ops: - self.log_skip(node, "no operator implementation") - return False + # For some ops, i.e. custom ops the name is registered instead of the + # OpOverload object. + if not isinstance(target, str) and target.name() in vulkan_supported_ops: + features = vulkan_supported_ops[target.name()] + else: + self.log_skip(node, "no operator implementation") + return False + else: + features = vulkan_supported_ops[target] - features = vulkan_supported_ops[target] + assert features is not None if not features.check_node_fn(node): self.log_skip(node, "op args not supported") diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 2c462013513..6dcf2fc4f45 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -176,17 +176,32 @@ void resize_sdpa_out( graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected)); } -void sdpa_with_kv_cache_impl( - ComputeGraph& graph, - const std::vector& args) { +void update_cache_impl(ComputeGraph& graph, const std::vector& args) { + int arg_idx = 0; + const ValueRef value = args[arg_idx++]; + const ValueRef cache = args[arg_idx++]; + const ValueRef input_pos_symint = args[arg_idx++]; + const ValueRef out = args[arg_idx++]; + + // Unused variables + (void)out; + + VK_CHECK_COND(graph.size_at(-4, value) == 1); + VK_CHECK_COND(graph.size_at(-4, cache) == 1); + VK_CHECK_COND( + graph.size_at(-1, value) == graph.size_at(-1, cache)); + VK_CHECK_COND( + graph.size_at(-2, value) == graph.size_at(-2, cache)); + + add_kv_cache_update_node(graph, input_pos_symint, value, cache); +} + +void sdpa_impl(ComputeGraph& graph, const std::vector& args) { int arg_idx = 0; const ValueRef q_projected = args[arg_idx++]; - const ValueRef k_projected = args[arg_idx++]; - const ValueRef v_projected = args[arg_idx++]; - const ValueRef k_cache_data = args[arg_idx++]; - const ValueRef v_cache_data = args[arg_idx++]; + const ValueRef k_cache = args[arg_idx++]; + const ValueRef v_cache = args[arg_idx++]; const ValueRef input_pos_symint = args[arg_idx++]; - const ValueRef sequence_len = args[arg_idx++]; const ValueRef attn_mask = args[arg_idx++]; const ValueRef dropout_p = args[arg_idx++]; const ValueRef is_causal = args[arg_idx++]; @@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl( // Output tensors const ValueRef out = args[arg_idx++]; - // Unused variables - (void)sequence_len; - // Batches must be 1 VK_CHECK_COND(graph.size_at(-4, q_projected) == 1); - VK_CHECK_COND(graph.size_at(-4, k_projected) == 1); - VK_CHECK_COND(graph.size_at(-4, v_projected) == 1); + VK_CHECK_COND(graph.size_at(-4, k_cache) == 1); + VK_CHECK_COND(graph.size_at(-4, v_cache) == 1); // k and v projected must have the same shape - VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected)); + VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache)); // head dim must match between tensors VK_CHECK_COND( graph.size_at(-1, q_projected) == - graph.size_at(-1, k_projected)); + graph.size_at(-1, k_cache)); // All tensors must have the packed dim be the width (head) dimension VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim); - VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim); - VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim); + VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim); // Some variables are not supported yet VK_CHECK_COND( graph.val_is_none(dropout_p) || @@ -222,16 +234,8 @@ void sdpa_with_kv_cache_impl( graph.val_is_none(is_causal) || graph.extract_scalar(is_causal)); VK_CHECK_COND(graph.val_is_none(attn_mask)); - const ValueRef k_cache = - prepack_standard_like(graph, k_cache_data, q_projected); - const ValueRef v_cache = - prepack_standard_like(graph, v_cache_data, q_projected); - const int32_t max_seq_len = graph.size_at(1, k_cache); - add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache); - add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache); - // Slice caches from 0 to input_pos + sequence_len const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache); const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache); @@ -257,7 +261,7 @@ void sdpa_with_kv_cache_impl( // Repeat interleave const int64_t num_heads = graph.size_at(2, q_projected); - const int64_t num_kv_heads = graph.size_at(2, k_projected); + const int64_t num_kv_heads = graph.size_at(2, k_cache); const ValueRef num_repeats = graph.add_scalar(num_heads / num_kv_heads); @@ -331,8 +335,52 @@ void sdpa_with_kv_cache_impl( new ExecuteNode(resize_sdpa_out, {q_projected, out})); } +void sdpa_with_kv_cache_impl( + ComputeGraph& graph, + const std::vector& args) { + int arg_idx = 0; + const ValueRef q_projected = args[arg_idx++]; + const ValueRef k_projected = args[arg_idx++]; + const ValueRef v_projected = args[arg_idx++]; + const ValueRef k_cache_data = args[arg_idx++]; + const ValueRef v_cache_data = args[arg_idx++]; + const ValueRef input_pos_symint = args[arg_idx++]; + const ValueRef sequence_len = args[arg_idx++]; + const ValueRef attn_mask = args[arg_idx++]; + const ValueRef dropout_p = args[arg_idx++]; + const ValueRef is_causal = args[arg_idx++]; + const ValueRef scale = args[arg_idx++]; + + // Output tensors + const ValueRef out = args[arg_idx++]; + + (void)sequence_len; + + const ValueRef k_cache = + prepack_standard_like(graph, k_cache_data, q_projected); + const ValueRef v_cache = + prepack_standard_like(graph, v_cache_data, q_projected); + + update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); + update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); + + sdpa_impl( + graph, + {q_projected, + k_cache, + v_cache, + input_pos_symint, + attn_mask, + dropout_p, + is_causal, + scale, + out}); +} + REGISTER_OPERATORS { VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl); + VK_REGISTER_OP(update_cache.default, update_cache_impl); + VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl); } } // namespace vkcompute diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index df6c930eb48..618c74e8706 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -21,6 +21,8 @@ import pkg_resources import torch + +from executorch.backends.vulkan._passes.remove_asserts import remove_asserts from executorch.devtools.backend_debug import get_delegation_info from executorch.devtools.etrecord import generate_etrecord @@ -727,6 +729,10 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901 ) modelname = f"vulkan_{modelname}" + # Need to remove asserts from the graph to prevent graph breaks + # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`. + remove_asserts(builder_exported_to_edge.edge_manager.exported_program()) + if args.mps: partitioners.append(get_mps_partitioner(args.use_kv_cache)) modelname = f"mps_{modelname}"