
Commit 0b8a4c9

[ET-VK][ez] Fix handling of assert ops
Pull Request resolved: #11258

## Changes

* Apply `RemoveAssertsTransform` as part of `vulkan_preprocess`.
* Do not call `RemoveAssertsTransform` before lowering the graph.
* Register the assert-related ops in the operator registry as ephemeral ops.

## Motivation

Assert ops are not implemented in Vulkan, so `RemoveAssertsTransform()` was previously called on the graph before the lowering process. However, it turns out that the assertion ops are required to handle dynamic shapes properly, because they place constraints on the possible range of symbolic integers. If they are absent, re-tracing the graph during a recompile (which may occur during a graph transform pass) may fail. Therefore, instead of calling the transform before lowering, call it inside `vulkan_preprocess` after the point where subsequent passes will no longer attempt to trace the graph.

ghstack-source-id: 287878314
@exported-using-ghexport

Differential Revision: [D75686048](https://our.internmc.facebook.com/intern/diff/D75686048/)
1 parent: 344bc14
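For context on why the guard nodes matter, here is a minimal sketch (illustrative only; the model and names below are not from this PR) of how `torch.export` records `sym_constrain_range_for_size` and `_assert_scalar` nodes when a graph contains a data-dependent size. Later re-traces rely on the range constraints these nodes carry.

```python
# Illustrative sketch, not from this PR: a data-dependent size produces an
# unbacked symbolic int, and export records guard/assert nodes
# (aten.sym_constrain_range_for_size.default, aten._assert_scalar.default)
# that bound it. Stripping those nodes before a later pass re-traces the
# graph discards the bounds, which is what made the re-trace fail.
import torch
from torch.export import export


class ModelWithDynamicSize(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n = x.max().item()       # data-dependent scalar -> unbacked symint
        torch._check_is_size(n)  # constrains n to a valid (non-negative) size
        return torch.zeros(n)


ep = export(ModelWithDynamicSize(), (torch.randint(1, 10, (4,)),))
# The exported graph should contain sym_constrain_range_for_size /
# _assert_scalar nodes guarding the range of n.
print(ep.graph)
```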

8 files changed: +18 −11 lines changed

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -17,6 +17,7 @@
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
 
 #################
 ## linear_qcnw ##
@@ -224,6 +225,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         )
 
         graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
+        dead_code_elimination_pass(graph_module)
 
+        # Re-trace the graph since new nodes were (potentially) inserted
+        graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
```

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from copy import deepcopy
 from typing import Any, Optional, Set
 
 import executorch.backends.vulkan.utils as utils
```

backends/vulkan/op_registry.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -231,6 +231,13 @@ def update_features_impl(op: OpKey):
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
         operator.add,
+        operator.lt,
+        operator.gt,
+        operator.ge,
+        operator.le,
+        # Guard and assert ops
+        torch.ops.aten._assert_scalar.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
```
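The comparison operators are registered alongside the assert ops because a range guard typically lowers to a comparison node feeding an assert node; if only the assert op were registered, the partitioner would reject the comparison and split the subgraph. A hypothetical sketch of the membership check (the real registry maps each op to an `OpFeatures` entry rather than a plain set, and "ephemeral" here presumably means the op is kept for symbolic-shape bookkeeping rather than dispatching a shader):

```python
import operator

import torch

# Hypothetical stand-in for the registry: the symint and guard/assert ops
# that the Vulkan delegate accepts even though they perform no tensor compute.
EPHEMERAL_OPS = {
    torch.ops.aten.sym_size.int,
    operator.add,
    operator.lt,
    operator.gt,
    operator.ge,
    operator.le,
    torch.ops.aten._assert_scalar.default,
    torch.ops.aten.sym_constrain_range_for_size.default,
}


def is_ephemeral(target) -> bool:
    """Return True for ops kept in the graph purely for symbolic-shape
    bookkeeping, e.g. `operator.ge(sym_n, 1)` feeding `_assert_scalar`."""
    return target in EPHEMERAL_OPS
```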

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 3 deletions

```diff
@@ -146,10 +146,11 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
     def node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
-        if utils.is_symint_node(node):
-            return node.target in vulkan_supported_ops, "Op is compatible"
-        elif utils.is_tensor_node(node):
+        if utils.is_tensor_node(node):
             return self.op_node_is_compatible(node, features=features)
+        # For non-tensor nodes, just check if the op is registered
+        elif hasattr(node, "target"):
+            return node.target in vulkan_supported_ops, "Op is compatible"
 
         return False, f"Unsupported node type: {node.format_node()}"
```

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -353,7 +353,7 @@ def process_call_function_node(self, node) -> None:
             # previously encountered, then use the existing Value id.
             operator_call_args.append(self.get_or_create_value_for(function_arg))
         else:
-            for i, arg_node in enumerate(node.args):
+            for _, arg_node in enumerate(node.args):
                 operator_call_args.append(self.get_or_create_value_for(arg_node))
 
         # Add output node
```

backends/vulkan/vulkan_preprocess.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -29,6 +29,7 @@
     SqueezeUnsqueezeInputs,
     TagMemoryMetaPass,
 )
+from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform
 
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
@@ -172,6 +173,7 @@ def preprocess(  # noqa: C901
         program = apply_passes(
             program,
             [
+                RemoveAssertsTransform(),
                 # Since this pass may replace a scalar argument with a tensor argument,
                 # this pass may result in a non ATen compliant graph structure.
                 RemoveLocalScalarDenseOpsTransform(),
```
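For reference, here is a minimal sketch of what an assert-removal pass can look like (an assumed shape, not the actual `RemoveAssertsTransform` implementation): it erases the guard nodes registered above, which is safe at this point in `vulkan_preprocess` because no subsequent pass re-traces the graph.

```python
import torch
from executorch.exir.pass_base import ExportPass, PassResult

# Guard/assert ops to strip; mirrors the ops registered in op_registry.py.
_ASSERT_OPS = {
    torch.ops.aten._assert_scalar.default,
    torch.ops.aten.sym_constrain_range_for_size.default,
}


class RemoveAssertsSketch(ExportPass):
    """Minimal sketch, not the real RemoveAssertsTransform: erase assert and
    range-constraint nodes, which have no users, from the graph."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in list(graph_module.graph.nodes):
            if node.op == "call_function" and node.target in _ASSERT_OPS:
                graph_module.graph.erase_node(node)
        graph_module.graph.lint()
        graph_module.recompile()
        return PassResult(graph_module, True)
```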

examples/models/llama/TARGETS

Lines changed: 0 additions & 1 deletion

```diff
@@ -148,7 +148,6 @@ runtime.python_library(
         ":source_transformation",
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
-        "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
```

examples/models/llama/export_llama_lib.py

Lines changed: 0 additions & 4 deletions

```diff
@@ -24,7 +24,6 @@
 import pkg_resources
 import torch
 
-from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import print_delegation_info
 
 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
@@ -880,9 +879,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
         )
         modelname = f"vulkan_{modelname}"
 
-        # Need to remove asserts from the graph to prevent graph breaks
-        remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
-
     if mps:
         partitioners.append(get_mps_partitioner(use_kv_cache))
         modelname = f"mps_{modelname}"
```
