Skip to content

Commit 632c0d0

Browse files
committed
[ET-VK][LlaMa] Add RemoveAsserts pass and apply it during LlaMa export
## Context

Recently, some assertion ops were added to the Llama source code. This causes issues for the Vulkan delegate: runtime assertions are not yet supported in Vulkan, so the assertion ops are unsupported and cause graph breaks when they are encountered. To prevent these graph breaks when delegating to Vulkan, apply a pass that removes assertion ops during Llama export.

Differential Revision: [D68919678](https://our.internmc.facebook.com/intern/diff/D68919678/) [ghstack-poisoned]
1 parent db534b2 commit 632c0d0

File tree

5 files changed

+78
-1
lines changed

5 files changed

+78
-1
lines changed

backends/vulkan/_passes/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,19 @@ runtime.python_library(
3030
]
3131
)
3232

33+
runtime.python_library(
34+
name = "remove_asserts",
35+
srcs = ["remove_asserts.py"],
36+
visibility = [
37+
"//executorch/backends/...",
38+
],
39+
deps = [
40+
"//caffe2:torch",
41+
"//executorch/exir:pass_base",
42+
"//executorch/exir/dialects:lib",
43+
],
44+
)
45+
3346
runtime.python_library(
3447
name = "remove_local_scalar_dense",
3548
srcs = ["remove_local_scalar_dense_ops.py"],
@@ -83,6 +96,7 @@ runtime.python_library(
8396
deps = [
8497
":insert_prepack_nodes",
8598
":int4_weight_only_quantizer",
99+
":remove_asserts",
86100
":remove_local_scalar_dense",
87101
":remove_redundant_ops",
88102
":tag_memory_meta_pass"

backends/vulkan/_passes/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
33
VkInt4WeightOnlyQuantizer,
44
)
5+
from executorch.backends.vulkan._passes.remove_asserts import (
6+
remove_asserts,
7+
RemoveAssertsTransform,
8+
)
59
from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
610
RemoveLocalScalarDenseOpsTransform,
711
)
@@ -13,6 +17,8 @@
1317
__all__ = [
1418
"insert_prepack_nodes",
1519
"VkInt4WeightOnlyQuantizer",
20+
"remove_asserts",
21+
"RemoveAssertsTransform",
1622
"RemoveLocalScalarDenseOpsTransform",
1723
"RemoveRedundantOpsTransform",
1824
"TagMemoryMetaPass",
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from typing import Set, Union
10+
11+
import torch
12+
13+
from executorch.exir.dialects.edge._ops import EdgeOpOverload
14+
from executorch.exir.pass_base import ExportPass, PassResult
15+
from executorch.exir.program._program import _get_updated_graph_signature
16+
17+
from torch.export.exported_program import ExportedProgram
18+
19+
OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]
20+
21+
22+
class RemoveAssertsTransform(ExportPass):
    """
    Remove operators which perform assertions. These are not possible to execute in
    Vulkan since GLSL shaders cannot abort execution at runtime. Therefore, remove these
    operators.
    """

    # Runtime assertion ops with no Vulkan implementation. They produce no values
    # consumed by other nodes, so erasing them does not change graph outputs.
    assert_ops: Set[OpType] = {
        torch.ops.aten._assert_scalar.default,
        torch.ops.aten.sym_constrain_range_for_size.default,
    }

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Erase all assertion ops from ``graph_module`` and return the mutated module.

        The nodes to remove are collected into a snapshot list before any
        erasure: erasing nodes while iterating the live ``graph.nodes`` list
        relies on torch.fx linked-list internals keeping the iterator valid.
        """
        nodes_to_erase = [
            node for node in graph_module.graph.nodes if node.target in self.assert_ops
        ]
        for node in nodes_to_erase:
            graph_module.graph.erase_node(node)

        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
42+
43+
44+
def remove_asserts(edge_program: ExportedProgram) -> ExportedProgram:
    """Strip assertion ops from ``edge_program``'s graph in place.

    Applies ``RemoveAssertsTransform`` to the program's graph module, then
    refreshes the graph signature so it matches the mutated graph and re-runs
    program validation. Returns the same ``ExportedProgram`` instance.
    """
    gm = edge_program.graph_module
    RemoveAssertsTransform()(gm)

    updated_signature = _get_updated_graph_signature(
        edge_program.graph_signature, gm
    )
    edge_program._graph_signature = updated_signature
    edge_program._validate()
    return edge_program

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
490490

491491

492492
# TODO(ssjia) allow registration after remove assertions pass is implemented
493-
# @update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
493+
@update_features(["llama::update_cache", exir_ops.edge.llama.custom_sdpa.default])
494494
def register_sdpa_ops(features: OpFeatures):
495495
features.texture_impl = TextureImplFeatures(
496496
valid_packed_dims={PackedDim.WIDTH},

examples/models/llama/export_llama_lib.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
import pkg_resources
2323
import torch
24+
25+
from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
2426
from executorch.devtools.backend_debug import get_delegation_info
2527

2628
from executorch.devtools.etrecord import generate_etrecord
@@ -727,6 +729,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
727729
)
728730
modelname = f"vulkan_{modelname}"
729731

732+
# Need to remove asserts from the graph to prevent graph breaks
733+
remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
734+
730735
if args.mps:
731736
partitioners.append(get_mps_partitioner(args.use_kv_cache))
732737
modelname = f"mps_{modelname}"

0 commit comments

Comments
 (0)