diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 74048cfb6a7..4e60fc7bd7e 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -30,6 +30,19 @@ runtime.python_library(
     ]
 )
 
+runtime.python_library(
+    name = "remove_asserts",
+    srcs = ["remove_asserts.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "remove_local_scalar_dense",
     srcs = ["remove_local_scalar_dense_ops.py"],
@@ -83,6 +96,7 @@ runtime.python_library(
     deps = [
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
+        ":remove_asserts",
         ":remove_local_scalar_dense",
         ":remove_redundant_ops",
         ":tag_memory_meta_pass"
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 416339574ba..8c29f5488f3 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -2,6 +2,10 @@
 from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
     VkInt4WeightOnlyQuantizer,
 )
+from executorch.backends.vulkan._passes.remove_asserts import (
+    remove_asserts,
+    RemoveAssertsTransform,
+)
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
@@ -13,6 +17,8 @@
 __all__ = [
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
+    "remove_asserts",
+    "RemoveAssertsTransform",
     "RemoveLocalScalarDenseOpsTransform",
     "RemoveRedundantOpsTransform",
     "TagMemoryMetaPass",
diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py
index 7876806d6d1..bf1fc28ba56 100644
--- a/backends/vulkan/_passes/insert_prepack_nodes.py
+++ b/backends/vulkan/_passes/insert_prepack_nodes.py
@@ -60,6 +60,12 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
     )
     # This pass assumes that the SpecPropPass() has already been applied
     assert "spec" in node.meta
+    # A mutable buffer will not be marked as constant, but it might as well be
+    # for the purposes of memory planning. Mark it as a constant tensor so that
+    # it is handled correctly by the memory planning pass.
+    if not node.meta["spec"].const:
+        assert is_param_node(program, node)
+        node.meta["spec"].const = True
     # Validate that the original node is marked as a constant. Constant tensors
     # do not participate in memory planning.
     assert node.meta["spec"].const
@@ -68,7 +74,9 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
     # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
     # memory object.
     prepack_node.meta["spec"].mem_obj_id = -1
-    node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
+    node.replace_all_uses_with(
+        prepack_node, lambda x, y=prepack_node: (x != y and x.op != "output")
+    )
 
     program.graph.eliminate_dead_code()
     return program
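A side note on the `replace_all_uses_with` change above: the second argument is torch.fx's `delete_user_cb` filter, and the new predicate additionally leaves uses by the graph's `output` node in place. A minimal sketch of the same pattern on a toy graph (the module and the `clone` stand-in node are hypothetical, not the real prepack node):

```python
import torch
import torch.fx as fx


class ToyModule(torch.nn.Module):
    """Hypothetical module; the second output keeps `x` wired to the output node."""

    def forward(self, x):
        return x + 1, x


gm = fx.symbolic_trace(ToyModule())
placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")

# Insert a stand-in for the prepack node right after the placeholder.
with gm.graph.inserting_after(placeholder):
    clone = gm.graph.call_function(torch.clone, (placeholder,))

# Reroute every user of `placeholder` to `clone`, except `clone` itself and the
# output node -- the same predicate shape used in insert_prepack_nodes.py.
placeholder.replace_all_uses_with(
    clone,
    delete_user_cb=lambda user: user is not clone and user.op != "output",
)
gm.recompile()
```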
diff --git a/backends/vulkan/_passes/remove_asserts.py b/backends/vulkan/_passes/remove_asserts.py
new file mode 100644
index 00000000000..835f2ec1415
--- /dev/null
+++ b/backends/vulkan/_passes/remove_asserts.py
@@ -0,0 +1,52 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Set, Union
+
+import torch
+
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.program._program import _get_updated_graph_signature
+
+from torch.export.exported_program import ExportedProgram
+
+OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]
+
+
+class RemoveAssertsTransform(ExportPass):
+    """
+    Remove assertion operators from the graph. These cannot be executed on
+    Vulkan, since GLSL shaders have no way to abort execution at runtime.
+    """
+
+    assert_ops: Set[OpType] = {
+        torch.ops.aten._assert_scalar.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
+    }
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            if node.target in self.assert_ops:
+                graph_module.graph.erase_node(node)
+
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, True)
+
+
+def remove_asserts(edge_program: ExportedProgram) -> ExportedProgram:
+    graph_module = edge_program.graph_module
+    RemoveAssertsTransform()(graph_module)
+
+    edge_program._graph_signature = _get_updated_graph_signature(
+        edge_program.graph_signature, graph_module
+    )
+    edge_program._validate()
+    return edge_program
diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index 1d08817e26a..f2f54404ca8 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -23,9 +23,6 @@
 
 from executorch.exir.pass_base import ExportPass, PassResult
 
-from torch.fx.passes.tools_common import NodeList
-from torch.fx.passes.utils.fuser_utils import topo_sort
-
 logger: logging.Logger = logging.getLogger("")
 logger.setLevel(logging.INFO)
 
@@ -220,9 +217,7 @@ def should_delay_annotation(self, node: torch.fx.Node) -> bool:  # noqa
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
-
-        for node in sorted_nodes:
+        for node in graph_module.graph.nodes:
             if not self.should_annotate(node) or self.should_delay_annotation(node):
                 continue
 
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index d70cf93b883..25cf74dc8f2 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -478,7 +478,7 @@ def register_convolution_op(features: OpFeatures):
 
 
 @update_features("llama::sdpa_with_kv_cache")
-def register_sdpa_op(features: OpFeatures):
+def register_sdpa_with_kv_cache_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         valid_packed_dims={PackedDim.WIDTH},
     )
@@ -489,6 +489,16 @@ def register_sdpa_op(features: OpFeatures):
     return features
 
 
+@update_features(["llama::update_cache", "llama::custom_sdpa"])
+def register_sdpa_ops(features: OpFeatures):
+    features.resize_fn = False
+    features.buffer_impl = False
+    features.texture_impl = TextureImplFeatures(
+        valid_packed_dims={PackedDim.WIDTH},
+    )
+    return features
+
+
 @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
 def register_rotary_emb_op(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
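Stepping back to the new pass introduced above: it can be exercised on its own. The following sketch is illustrative only; the toy model is hypothetical, and the exact assertion nodes that `torch.export` emits depend on the PyTorch version, but `torch._check` / `torch._check_is_size` on a data-dependent value is the typical way `_assert_scalar` and `sym_constrain_range_for_size` nodes end up in an exported graph:

```python
import torch

from executorch.backends.vulkan._passes import RemoveAssertsTransform


class ToyModel(torch.nn.Module):  # hypothetical model, for illustration only
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x.nonzero().shape[0]  # data-dependent size
        torch._check_is_size(n)   # typically lowers to sym_constrain_range_for_size
        torch._check(n <= 64)     # typically lowers to _assert_scalar
        return x.new_zeros(n)


ep = torch.export.export(ToyModel(), (torch.randn(8),))
# Erase the assertion nodes in place; GLSL shaders cannot abort at runtime,
# so these ops have no Vulkan implementation.
RemoveAssertsTransform()(ep.graph_module)
```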
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 3c31e0316a6..6ff3fa8d70f 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -250,11 +250,19 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:
             self.log_skip(node, "local scalar dense of incompatible op node")
             return False
 
+        features = None
         if target not in vulkan_supported_ops:
-            self.log_skip(node, "no operator implementation")
-            return False
+            # For some ops, e.g. custom ops, the name is registered instead of
+            # the OpOverload object.
+            if not isinstance(target, str) and target.name() in vulkan_supported_ops:
+                features = vulkan_supported_ops[target.name()]
+            else:
+                self.log_skip(node, "no operator implementation")
+                return False
+        else:
+            features = vulkan_supported_ops[target]
 
-        features = vulkan_supported_ops[target]
+        assert features is not None
 
         if not features.check_node_fn(node):
             self.log_skip(node, "op args not supported")
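To make the new lookup flow easier to follow, here is a condensed, self-contained restatement in plain Python. The registry contents and the `lookup_features` helper are stand-ins (the real table lives in backends/vulkan/op_registry.py), but the fallback mirrors the logic above: exact key first, then the op's string name for custom ops that were registered by name:

```python
from typing import Any, Dict, Optional, Union

import torch

# Stand-in registry: the real one maps OpOverload objects -- or string names,
# for custom ops -- to OpFeatures entries.
vulkan_supported_ops: Dict[Any, Any] = {
    torch.ops.aten.add.Tensor: "features-for-add",  # placeholder value
}


def lookup_features(target: Union[str, torch._ops.OpOverload]) -> Optional[Any]:
    # Fast path: the target itself is a registered key.
    if target in vulkan_supported_ops:
        return vulkan_supported_ops[target]
    # Fallback: custom ops may be registered under their string name
    # (e.g. "llama::custom_sdpa") rather than under the OpOverload object.
    if not isinstance(target, str) and target.name() in vulkan_supported_ops:
        return vulkan_supported_ops[target.name()]
    # No implementation: the partitioner logs the skip and rejects the node.
    return None
```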
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 2c462013513..6dcf2fc4f45 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -176,17 +176,32 @@ void resize_sdpa_out(
   graph->get_tensor(out)->virtual_resize(graph->sizes_of(q_projected));
 }
 
-void sdpa_with_kv_cache_impl(
-    ComputeGraph& graph,
-    const std::vector<ValueRef>& args) {
+void update_cache_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef value = args[arg_idx++];
+  const ValueRef cache = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef out = args[arg_idx++];
+
+  // Unused variable
+  (void)out;
+
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, value) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, cache) == 1);
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-1, value) == graph.size_at<int32_t>(-1, cache));
+  VK_CHECK_COND(
+      graph.size_at<int32_t>(-2, value) == graph.size_at<int32_t>(-2, cache));
+
+  add_kv_cache_update_node(graph, input_pos_symint, value, cache);
+}
+
+void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   int arg_idx = 0;
   const ValueRef q_projected = args[arg_idx++];
-  const ValueRef k_projected = args[arg_idx++];
-  const ValueRef v_projected = args[arg_idx++];
-  const ValueRef k_cache_data = args[arg_idx++];
-  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef k_cache = args[arg_idx++];
+  const ValueRef v_cache = args[arg_idx++];
   const ValueRef input_pos_symint = args[arg_idx++];
-  const ValueRef sequence_len = args[arg_idx++];
   const ValueRef attn_mask = args[arg_idx++];
   const ValueRef dropout_p = args[arg_idx++];
   const ValueRef is_causal = args[arg_idx++];
@@ -195,23 +210,20 @@ void sdpa_with_kv_cache_impl(
   // Output tensors
   const ValueRef out = args[arg_idx++];
 
-  // Unused variables
-  (void)sequence_len;
-
   // Batches must be 1
   VK_CHECK_COND(graph.size_at<int32_t>(-4, q_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_projected) == 1);
-  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_projected) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, k_cache) == 1);
+  VK_CHECK_COND(graph.size_at<int32_t>(-4, v_cache) == 1);
 
-  // k and v projected must have the same shape
-  VK_CHECK_COND(graph.sizes_of(k_projected) == graph.sizes_of(v_projected));
+  // k and v caches must have the same shape
+  VK_CHECK_COND(graph.sizes_of(k_cache) == graph.sizes_of(v_cache));
 
   // head dim must match between tensors
   VK_CHECK_COND(
       graph.size_at<int32_t>(-1, q_projected) ==
-      graph.size_at<int32_t>(-1, k_projected));
+      graph.size_at<int32_t>(-1, k_cache));
 
   // All tensors must have the packed dim be the width (head) dimension
   VK_CHECK_COND(graph.packed_dim_of(q_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(k_projected) == WHCN::kWidthDim);
-  VK_CHECK_COND(graph.packed_dim_of(v_projected) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(k_cache) == WHCN::kWidthDim);
+  VK_CHECK_COND(graph.packed_dim_of(v_cache) == WHCN::kWidthDim);
 
   // Some variables are not supported yet
   VK_CHECK_COND(
       graph.val_is_none(dropout_p) ||
@@ -222,16 +234,8 @@
       graph.val_is_none(is_causal) || graph.extract_scalar<bool>(is_causal));
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
-  const ValueRef k_cache =
-      prepack_standard_like(graph, k_cache_data, q_projected);
-  const ValueRef v_cache =
-      prepack_standard_like(graph, v_cache_data, q_projected);
-
   const int32_t max_seq_len = graph.size_at<int32_t>(1, k_cache);
 
-  add_kv_cache_update_node(graph, input_pos_symint, k_projected, k_cache);
-  add_kv_cache_update_node(graph, input_pos_symint, v_projected, v_cache);
-
   // Slice caches from 0 to input_pos + sequence_len
   const ValueRef k_cache_sliced = graph.add_tensor_view(k_cache);
   const ValueRef v_cache_sliced = graph.add_tensor_view(v_cache);
@@ -257,7 +261,7 @@
 
   // Repeat interleave
   const int64_t num_heads = graph.size_at<int64_t>(2, q_projected);
-  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_projected);
+  const int64_t num_kv_heads = graph.size_at<int64_t>(2, k_cache);
 
   const ValueRef num_repeats =
       graph.add_scalar<int64_t>(num_heads / num_kv_heads);
@@ -331,8 +335,52 @@
       new ExecuteNode(resize_sdpa_out, {q_projected, out}));
 }
 
+void sdpa_with_kv_cache_impl(
+    ComputeGraph& graph,
+    const std::vector<ValueRef>& args) {
+  int arg_idx = 0;
+  const ValueRef q_projected = args[arg_idx++];
+  const ValueRef k_projected = args[arg_idx++];
+  const ValueRef v_projected = args[arg_idx++];
+  const ValueRef k_cache_data = args[arg_idx++];
+  const ValueRef v_cache_data = args[arg_idx++];
+  const ValueRef input_pos_symint = args[arg_idx++];
+  const ValueRef sequence_len = args[arg_idx++];
+  const ValueRef attn_mask = args[arg_idx++];
+  const ValueRef dropout_p = args[arg_idx++];
+  const ValueRef is_causal = args[arg_idx++];
+  const ValueRef scale = args[arg_idx++];
+
+  // Output tensors
+  const ValueRef out = args[arg_idx++];
+
+  // Unused variable
+  (void)sequence_len;
+
+  const ValueRef k_cache =
+      prepack_standard_like(graph, k_cache_data, q_projected);
+  const ValueRef v_cache =
+      prepack_standard_like(graph, v_cache_data, q_projected);
+
+  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
+  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
+
+  sdpa_impl(
+      graph,
+      {q_projected,
+       k_cache,
+       v_cache,
+       input_pos_symint,
+       attn_mask,
+       dropout_p,
+       is_causal,
+       scale,
+       out});
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
+  VK_REGISTER_OP(update_cache.default, update_cache_impl);
+  VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
 }
 
 } // namespace vkcompute
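The net effect of the C++ restructuring is that `sdpa_with_kv_cache` becomes a thin composition of the two new ops. A minimal PyTorch sketch of the same decomposition may help reviewers follow the split; this is illustrative only (plain tensors instead of `ValueRef`s, single-method functions instead of graph nodes, and it omits the GQA repeat-interleave, masking, and scale handling that the shaders implement):

```python
import torch
import torch.nn.functional as F


def update_cache(value: torch.Tensor, cache: torch.Tensor, pos: int) -> None:
    """Toy analogue of llama::update_cache: write the new rows at `pos`.

    Tensors are laid out [batch, seq_len, n_heads, head_dim], matching the
    [1, max_seq_len, n_kv_heads, head_dim] caches checked in update_cache_impl.
    """
    seq_len = value.shape[1]
    cache[:, pos : pos + seq_len] = value


def custom_sdpa(q, k_cache, v_cache, pos):
    """Toy analogue of llama::custom_sdpa: attend over cache[0 : pos + seq]."""
    seq_len = q.shape[1]
    k = k_cache[:, : pos + seq_len]
    v = v_cache[:, : pos + seq_len]
    # Move to [batch, heads, seq, head_dim] for F.scaled_dot_product_attention.
    q_, k_, v_ = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q_, k_, v_)
    return out.transpose(1, 2)


def sdpa_with_kv_cache(q, k_proj, v_proj, k_cache, v_cache, pos):
    """The old fused op is now just the composition of the two ops above."""
    update_cache(k_proj, k_cache, pos)
    update_cache(v_proj, v_cache, pos)
    return custom_sdpa(q, k_cache, v_cache, pos)


# Example: prefill 4 tokens into empty caches, then attend over them.
q = torch.randn(1, 4, 8, 64)
k_proj = torch.randn(1, 4, 8, 64)
v_proj = torch.randn(1, 4, 8, 64)
k_cache = torch.zeros(1, 128, 8, 64)
v_cache = torch.zeros(1, 128, 8, 64)
out = sdpa_with_kv_cache(q, k_proj, v_proj, k_cache, v_cache, pos=0)
```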
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index df6c930eb48..618c74e8706 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -21,6 +21,8 @@
 import pkg_resources
 
 import torch
+
+from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import get_delegation_info
 
 from executorch.devtools.etrecord import generate_etrecord
@@ -727,6 +729,10 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         )
         modelname = f"vulkan_{modelname}"
 
+        # Remove asserts from the graph, since Vulkan cannot execute them and
+        # they would otherwise cause graph breaks during partitioning
+        # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
+        remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
+
     if args.mps:
         partitioners.append(get_mps_partitioner(args.use_kv_cache))
         modelname = f"mps_{modelname}"
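For reviewers who want to sanity-check this step outside the exporter, here is a minimal sketch of the call pattern, assuming `edge_manager` is the `EdgeProgramManager` produced by `to_edge()` earlier in the flow; unlike applying `RemoveAssertsTransform` directly, `remove_asserts` also refreshes the graph signature and re-validates the program:

```python
from executorch.backends.vulkan._passes import remove_asserts

# `edge_manager` is assumed to be an EdgeProgramManager; remove_asserts
# mutates and returns its underlying ExportedProgram.
edge_program = edge_manager.exported_program()
remove_asserts(edge_program)
```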