From c1214bdeb07ac6756af27a7eaa310a865209f69f Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Fri, 30 May 2025 08:26:16 -0700
Subject: [PATCH] [ET-VK][ez] Fix handling of assert ops

## Changes

* Apply `RemoveAssertsTransform` as part of `vulkan_preprocess`
* Do not call `RemoveAssertsTransform` before lowering the graph
* Register assert-related ops in the operator registry as ephemeral ops

## Motivation

Assert ops are not implemented in Vulkan, so previously `RemoveAssertsTransform()` was called on the graph before lowering. However, it turns out that assertion ops are required to handle dynamic shapes correctly, because they place constraints on the possible range of symbolic integers. If they are not present, re-tracing the graph during a recompile (which may occur during a graph transform pass) may fail. Therefore, instead of calling the transform before lowering, call it inside `vulkan_preprocess` after the point where subsequent passes will no longer attempt to trace the graph.
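To illustrate why these ops matter, here is a minimal sketch (not part of this diff; the module name and input values are illustrative) of how `torch.export` records guard ops for a data-dependent size:

```python
import torch
from torch.export import export


class DataDependentOnes(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # .item() produces an unbacked symbolic integer during export
        n = x.sum().item()
        # Constrain its range; export records this as
        # sym_constrain_range_for_size / _assert_scalar nodes in the graph
        torch._check_is_size(n)
        return torch.ones(n)


ep = export(DataDependentOnes(), (torch.ones(4, dtype=torch.int64),))
# The printed graph contains aten.sym_constrain_range_for_size.default and
# aten._assert_scalar.default nodes that bound the symbolic integer n
print(ep.graph_module.code)
```

If a pass strips these nodes and a later pass re-traces the graph, the range information for `n` is lost, which is exactly the failure mode described above.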
Differential Revision: [D75686048](https://our.internmc.facebook.com/intern/diff/D75686048/)

[ghstack-poisoned]
---
 backends/vulkan/_passes/fuse_quantized_ops.py         | 5 ++++-
 backends/vulkan/_passes/tag_memory_meta_pass.py       | 1 -
 backends/vulkan/op_registry.py                        | 7 +++++++
 backends/vulkan/partitioner/vulkan_partitioner.py     | 7 ++++---
 backends/vulkan/serialization/vulkan_graph_builder.py | 2 +-
 backends/vulkan/vulkan_preprocess.py                  | 2 ++
 examples/models/llama/TARGETS                         | 1 -
 examples/models/llama/export_llama_lib.py             | 4 ----
 8 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py
index d510e1d4342..805a5c1f744 100644
--- a/backends/vulkan/_passes/fuse_quantized_ops.py
+++ b/backends/vulkan/_passes/fuse_quantized_ops.py
@@ -17,6 +17,7 @@
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
 
 #################
 ## linear_qcnw ##
@@ -224,6 +225,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             )
 
         graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
+        dead_code_elimination_pass(graph_module)
 
+        # Re-trace the graph since new nodes were (potentially) inserted
+        graph_module = super().call(graph_module).graph_module
         return PassResult(graph_module, True)
diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index 667de2ae45f..b38d0e220fe 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from copy import deepcopy
 from typing import Any, Optional, Set
 
 import executorch.backends.vulkan.utils as utils
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 3c1f1eb40dc..90fea61318c 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -231,6 +231,13 @@ def update_features_impl(op: OpKey):
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
         operator.add,
+        operator.lt,
+        operator.gt,
+        operator.ge,
+        operator.le,
+        # Guard and assert ops
+        torch.ops.aten._assert_scalar.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index d690e886d40..cbf30f84196 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -146,10 +146,11 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
 
     def node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
-        if utils.is_symint_node(node):
-            return node.target in vulkan_supported_ops, "Op is compatible"
-        elif utils.is_tensor_node(node):
+        if utils.is_tensor_node(node):
             return self.op_node_is_compatible(node, features=features)
+        # For non-tensor nodes, just check if the op is registered
+        elif hasattr(node, "target"):
+            return node.target in vulkan_supported_ops, "Op is compatible"
 
         return False, f"Unsupported node type: {node.format_node()}"
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
index 62a15a22ede..78cc47bfb6d 100644
--- a/backends/vulkan/serialization/vulkan_graph_builder.py
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -353,7 +353,7 @@ def process_call_function_node(self, node) -> None:
             # previously encountered, then use the existing Value id.
             operator_call_args.append(self.get_or_create_value_for(function_arg))
         else:
-            for i, arg_node in enumerate(node.args):
+            for _, arg_node in enumerate(node.args):
                 operator_call_args.append(self.get_or_create_value_for(arg_node))
 
         # Add output node
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 4200df3e131..a22afc3f42e 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -29,6 +29,7 @@
     SqueezeUnsqueezeInputs,
     TagMemoryMetaPass,
 )
+from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
@@ -172,6 +173,7 @@ def preprocess(  # noqa: C901
         program = apply_passes(
             program,
             [
+                RemoveAssertsTransform(),
                 # Since this pass may replace a scalar argument with a tensor argument,
                 # this pass may result in a non ATen compliant graph structure.
                 RemoveLocalScalarDenseOpsTransform(),
diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index f2aa396f7a1..872eccce872 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -148,7 +148,6 @@ runtime.python_library(
         ":source_transformation",
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
-        "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 3a3102886f8..96faf64475e 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -24,7 +24,6 @@
 import pkg_resources
 
 import torch
-from executorch.backends.vulkan._passes.remove_asserts import remove_asserts
 from executorch.devtools.backend_debug import print_delegation_info
 
 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
@@ -880,9 +879,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
             )
             modelname = f"vulkan_{modelname}"
 
-            # Need to remove asserts from the graph to prevent graph breaks
-            remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
-
         if mps:
             partitioners.append(get_mps_partitioner(use_kv_cache))
             modelname = f"mps_{modelname}"