14 changes: 14 additions & 0 deletions backends/vulkan/_passes/TARGETS
@@ -30,6 +30,19 @@ runtime.python_library(
]
)

runtime.python_library(
name = "remove_asserts",
srcs = ["remove_asserts.py"],
visibility = [
"//executorch/backends/...",
],
deps = [
"//caffe2:torch",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
],
)

runtime.python_library(
name = "remove_local_scalar_dense",
srcs = ["remove_local_scalar_dense_ops.py"],
@@ -83,6 +96,7 @@ runtime.python_library(
deps = [
":insert_prepack_nodes",
":int4_weight_only_quantizer",
":remove_asserts",
":remove_local_scalar_dense",
":remove_redundant_ops",
":tag_memory_meta_pass"
6 changes: 6 additions & 0 deletions backends/vulkan/_passes/__init__.py
@@ -2,6 +2,10 @@
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
VkInt4WeightOnlyQuantizer,
)
from executorch.backends.vulkan._passes.remove_asserts import (
remove_asserts,
RemoveAssertsTransform,
)
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
RemoveLocalScalarDenseOpsTransform,
)
@@ -13,6 +17,8 @@
__all__ = [
"insert_prepack_nodes",
"VkInt4WeightOnlyQuantizer",
"remove_asserts",
"RemoveAssertsTransform",
"RemoveLocalScalarDenseOpsTransform",
"RemoveRedundantOpsTransform",
"TagMemoryMetaPass",
52 changes: 52 additions & 0 deletions backends/vulkan/_passes/remove_asserts.py
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Set, Union

import torch

from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.program._program import _get_updated_graph_signature

from torch.export.exported_program import ExportedProgram

OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]


class RemoveAssertsTransform(ExportPass):
"""
Remove operators which perform assertions. These are not possible to execute in
Vulkan since GLSL shaders cannot abort execution at runtime. Therefore, remove these
operators.
"""

assert_ops: Set[OpType] = {
torch.ops.aten._assert_scalar.default,
torch.ops.aten.sym_constrain_range_for_size.default,
}

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
for node in graph_module.graph.nodes:
if node.target in self.assert_ops:
graph_module.graph.erase_node(node)

graph_module.graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)


def remove_asserts(edge_program: ExportedProgram) -> ExportedProgram:
graph_module = edge_program.graph_module
RemoveAssertsTransform()(graph_module)

edge_program._graph_signature = _get_updated_graph_signature(
edge_program.graph_signature, graph_module
)
edge_program._validate()
return edge_program
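
For context, here is a minimal end-to-end sketch of how the new pass can be exercised. It is not part of this diff: the toy module, the example input, and the use of a plain torch.export program are illustrative assumptions, while in this PR the pass is applied to the edge program produced in export_llama_lib.py.

# Illustrative sketch only: a toy module whose export typically produces the
# assertion ops targeted by RemoveAssertsTransform, followed by the pass
# removing them.
import torch

from executorch.backends.vulkan._passes import remove_asserts


class TakeScalar(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x.item()  # unbacked SymInt
        torch._check_is_size(n)  # export inserts sym_constrain_range_for_size
        return torch.zeros(n)


ep = torch.export.export(TakeScalar(), (torch.tensor(4),))
print(ep.graph)  # expect aten.sym_constrain_range_for_size / aten._assert_scalar nodes

ep = remove_asserts(ep)  # assert nodes erased, graph signature refreshed
print(ep.graph)
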
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
@@ -490,7 +490,7 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):


# TODO(ssjia) allow registration after remove assertions pass is implemented
# @update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
@update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
def register_sdpa_ops(features: OpFeatures):
features.texture_impl = TextureImplFeatures(
valid_packed_dims={PackedDim.WIDTH},
5 changes: 5 additions & 0 deletions examples/models/llama/export_llama_lib.py
@@ -21,6 +21,8 @@

import pkg_resources
import torch

from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
from executorch.devtools.backend_debug import get_delegation_info

from executorch.devtools.etrecord import generate_etrecord
@@ -727,6 +729,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
)
modelname = f"vulkan_{modelname}"

# Need to remove asserts from the graph to prevent graph breaks
remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

if args.mps:
partitioners.append(get_mps_partitioner(args.use_kv_cache))
modelname = f"mps_{modelname}"
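
A note on the export_llama_lib.py hunk above: assert ops are removed from the edge program before delegation runs, since assert nodes cannot be lowered to Vulkan and any left in place would split the model into multiple delegate partitions (the "graph breaks" the new comment refers to). A condensed sketch of the Vulkan branch follows; the partitioner helper name and its arguments are assumptions for illustration, not confirmed by this diff.

# Condensed, illustrative sketch of the Vulkan branch shown above (not literal code).
if args.vulkan:
    # partitioner helper name and arguments are assumed for illustration
    partitioners.append(get_vulkan_partitioner(args.dtype_override, args.enable_dynamic_shape))
    modelname = f"vulkan_{modelname}"

    # strip assert ops now; they cannot be lowered to Vulkan and would otherwise
    # break the graph into multiple partitions when delegation runs
    remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
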