diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 74048cfb6a7..4e60fc7bd7e 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -30,6 +30,19 @@ runtime.python_library(
     ]
 )
 
+runtime.python_library(
+    name = "remove_asserts",
+    srcs = ["remove_asserts.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "remove_local_scalar_dense",
     srcs = ["remove_local_scalar_dense_ops.py"],
@@ -83,6 +96,7 @@ runtime.python_library(
     deps = [
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
+        ":remove_asserts",
         ":remove_local_scalar_dense",
         ":remove_redundant_ops",
         ":tag_memory_meta_pass"
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 416339574ba..8c29f5488f3 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -2,6 +2,10 @@
 from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
     VkInt4WeightOnlyQuantizer,
 )
+from executorch.backends.vulkan._passes.remove_asserts import (
+    remove_asserts,
+    RemoveAssertsTransform,
+)
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
@@ -13,6 +17,8 @@
 __all__ = [
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
+    "remove_asserts",
+    "RemoveAssertsTransform",
     "RemoveLocalScalarDenseOpsTransform",
     "RemoveRedundantOpsTransform",
     "TagMemoryMetaPass",
diff --git a/backends/vulkan/_passes/remove_asserts.py b/backends/vulkan/_passes/remove_asserts.py
new file mode 100644
index 00000000000..835f2ec1415
--- /dev/null
+++ b/backends/vulkan/_passes/remove_asserts.py
@@ -0,0 +1,80 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Set, Union
+
+import torch
+
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.program._program import _get_updated_graph_signature
+
+from torch.export.exported_program import ExportedProgram
+
+OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]
+
+
+class RemoveAssertsTransform(ExportPass):
+    """
+    Remove operators which perform assertions. These are not possible to execute in
+    Vulkan since GLSL shaders cannot abort execution at runtime. Therefore, remove these
+    operators.
+    """
+
+    # Call targets that are treated as assertions and erased from the graph.
+    assert_ops: Set[OpType] = {
+        torch.ops.aten._assert_scalar.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
+    }
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        """
+        Erase every node whose target is in `assert_ops`, then run dead code
+        elimination and recompile the module.
+
+        Args:
+            graph_module: The module whose graph is edited in place.
+
+        Returns:
+            A PassResult wrapping the same `graph_module`; `modified` is True only
+            if a node was erased here or dead code elimination changed the graph.
+        """
+        graph = graph_module.graph
+        # Snapshot the matching nodes first so that the node list is not mutated
+        # while it is being iterated.
+        assert_nodes = [node for node in graph.nodes if node.target in self.assert_ops]
+        for node in assert_nodes:
+            graph.erase_node(node)
+
+        # eliminate_dead_code() reports whether it changed the graph.
+        dce_modified = graph.eliminate_dead_code()
+        graph_module.recompile()
+        return PassResult(graph_module, bool(assert_nodes) or dce_modified)
+
+
+def remove_asserts(edge_program: ExportedProgram) -> ExportedProgram:
+    """
+    Run RemoveAssertsTransform on the graph module of `edge_program`.
+
+    Afterwards the program's graph signature is recomputed from the edited graph
+    module and the program's `_validate()` is called.
+
+    Args:
+        edge_program: The exported program to strip assertion operators from.
+
+    Returns:
+        The same `edge_program` object that was passed in.
+    """
+    graph_module = edge_program.graph_module
+    RemoveAssertsTransform()(graph_module)
+
+    edge_program._graph_signature = _get_updated_graph_signature(
+        edge_program.graph_signature, graph_module
+    )
+    edge_program._validate()
+    return edge_program
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b8eca829047..81b1b4c8fa2 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -490,7 +490,6 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
 
 
-# TODO(ssjia) allow registration after remove assertions pass is implemented
-# @update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
+@update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
 def register_sdpa_ops(features: OpFeatures):
     features.texture_impl = TextureImplFeatures(
         valid_packed_dims={PackedDim.WIDTH},
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index df6c930eb48..ccda4b76bbd 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -21,6 +21,8 @@
 
 import pkg_resources
 import torch
+
+from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 
 from executorch.devtools.backend_debug import get_delegation_info
 from executorch.devtools.etrecord import generate_etrecord
@@ -727,6 +729,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         )
         modelname = f"vulkan_{modelname}"
 
+        # Need to remove asserts from the graph to prevent graph breaks
+        remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
+
     if args.mps:
         partitioners.append(get_mps_partitioner(args.use_kv_cache))
         modelname = f"mps_{modelname}"