diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py
index d510e1d4342..805a5c1f744 100644
--- a/backends/vulkan/_passes/fuse_quantized_ops.py
+++ b/backends/vulkan/_passes/fuse_quantized_ops.py
@@ -17,6 +17,7 @@
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass

 #################
 ## linear_qcnw ##
@@ -224,6 +225,8 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             )

         graph_module.recompile()
-        graph_module = super().call(graph_module).graph_module
+        dead_code_elimination_pass(graph_module)
+        # Re-trace the graph since new nodes were (potentially) inserted
+        graph_module = super().call(graph_module).graph_module

         return PassResult(graph_module, True)
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 3c1f1eb40dc..90fea61318c 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -231,6 +231,13 @@ def update_features_impl(op: OpKey):
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
         operator.add,
+        operator.lt,
+        operator.gt,
+        operator.ge,
+        operator.le,
+        # Guard and assert ops
+        torch.ops.aten._assert_scalar.default,
+        torch.ops.aten.sym_constrain_range_for_size.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index d690e886d40..cbf30f84196 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -146,10 +146,11 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex

     def node_is_compatible(
         self, node: torch.fx.Node, features: Optional[OpFeatures] = None
     ) -> Tuple[bool, str]:
-        if utils.is_symint_node(node):
-            return node.target in vulkan_supported_ops, "Op is compatible"
-        elif utils.is_tensor_node(node):
+        if utils.is_tensor_node(node):
             return self.op_node_is_compatible(node, features=features)
+        # For non-tensor nodes, just check if the op is registered
+        elif hasattr(node, "target"):
+            return node.target in vulkan_supported_ops, "Op is compatible"

         return False, f"Unsupported node type: {node.format_node()}"
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 4200df3e131..a22afc3f42e 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -29,6 +29,7 @@
     SqueezeUnsqueezeInputs,
     TagMemoryMetaPass,
 )
+from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform

 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
@@ -172,6 +173,7 @@ def preprocess(  # noqa: C901
         program = apply_passes(
             program,
             [
+                RemoveAssertsTransform(),
                 # Since this pass may replace a scalar argument with a tensor argument,
                 # this pass may result in a non ATen compliant graph structure.
                 RemoveLocalScalarDenseOpsTransform(),
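[Note] RemoveAssertsTransform itself lives in backends/vulkan/_passes/remove_asserts.py and is not part of this diff. Below is a minimal sketch of what an assert-stripping pass of this shape can look like, assuming the ExportPass/PassResult structure used above; RemoveAssertsSketch and _ASSERT_OPS are hypothetical names, not the real implementation:

    import torch
    from executorch.exir.pass_base import ExportPass, PassResult

    # The same ops registered as "Guard and assert ops" in op_registry.py above.
    _ASSERT_OPS = {
        torch.ops.aten._assert_scalar.default,
        torch.ops.aten.sym_constrain_range_for_size.default,
    }

    class RemoveAssertsSketch(ExportPass):
        def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
            for node in list(graph_module.graph.nodes):
                # Assert/guard nodes return None and have no downstream users,
                # so they can be erased without rewiring the graph.
                if node.op == "call_function" and node.target in _ASSERT_OPS:
                    graph_module.graph.erase_node(node)
            graph_module.recompile()
            return PassResult(graph_module, True)

Running such a pass inside vulkan_preprocess means assert and guard nodes only need to survive partitioning (hence their registration as ephemeral ops above) before being dropped prior to serialization.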
diff --git a/examples/models/llama/TARGETS b/examples/models/llama/TARGETS
index f2aa396f7a1..872eccce872 100644
--- a/examples/models/llama/TARGETS
+++ b/examples/models/llama/TARGETS
@@ -148,7 +148,6 @@ runtime.python_library(
         ":source_transformation",
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
-        "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/exir/passes:init_mutable_pass",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 3a3102886f8..96faf64475e 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -24,7 +24,6 @@
 import pkg_resources
 import torch

-from executorch.backends.vulkan._passes.remove_asserts import remove_asserts

 from executorch.devtools.backend_debug import print_delegation_info
 from executorch.devtools.etrecord import generate_etrecord as generate_etrecord_func
@@ -880,9 +879,6 @@ def _to_edge_and_lower_llama(  # noqa: C901
         )
         modelname = f"vulkan_{modelname}"

-        # Need to remove asserts from the graph to prevent graph breaks
-        remove_asserts(builder_exported_to_edge.edge_manager.exported_program())
-
     if mps:
         partitioners.append(get_mps_partitioner(use_kv_cache))
         modelname = f"mps_{modelname}"
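[Note] With RemoveAssertsTransform now applied inside vulkan_preprocess, the llama export flow no longer needs to strip asserts itself; the deleted //executorch/backends/vulkan/_passes:vulkan_passes dependency and remove_asserts() call above follow from that. For context on where such nodes originate, here is a small illustrative export (the Toy module and all names are hypothetical) that produces the kinds of non-tensor nodes the partitioner must now classify:

    import torch
    from torch.export import Dim, export

    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            n = x.shape[0]           # lowers to torch.ops.aten.sym_size.int
            return x[: n - 1] + 1    # slicing on the SymInt introduces shape guards

    ep = export(Toy(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: Dim("batch", min=2)}})
    # Depending on the guards, the graph can also contain symbolic comparisons
    # (operator.lt/gt/ge/le) and runtime asserts such as aten._assert_scalar.default.
    print(ep.graph_module.graph)

Since assert nodes return None, they are neither tensor nor SymInt producers, which is why node_is_compatible now falls back to a generic "is the target registered" check rather than the narrower is_symint_node test.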