diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index ed3d8479331..74048cfb6a7 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -43,6 +43,19 @@ runtime.python_library(
     ],
 )
 
+runtime.python_library(
+    name = "remove_redundant_ops",
+    srcs = ["remove_redundant_ops.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+        "//executorch/exir/dialects:lib",
+    ],
+)
+
 runtime.python_library(
     name = "tag_memory_meta_pass",
     srcs = ["tag_memory_meta_pass.py"],
@@ -71,6 +84,7 @@ runtime.python_library(
         ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
         ":remove_local_scalar_dense",
+        ":remove_redundant_ops",
        ":tag_memory_meta_pass"
    ]
)
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 8823553ab13..416339574ba 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -5,11 +5,15 @@
 from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import (
     RemoveLocalScalarDenseOpsTransform,
 )
+from executorch.backends.vulkan._passes.remove_redundant_ops import (
+    RemoveRedundantOpsTransform,
+)
 from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass
 
 __all__ = [
     "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
+    "RemoveRedundantOpsTransform",
     "TagMemoryMetaPass",
 ]
diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py
new file mode 100644
index 00000000000..530505f7003
--- /dev/null
+++ b/backends/vulkan/_passes/remove_redundant_ops.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import Set, Union
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
+from executorch.exir.pass_base import ExportPass, PassResult
+from executorch.exir.passes import dead_code_elimination_pass
+
+OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload]
+
+
+class RemoveRedundantOpsTransform(ExportPass):
+    """
+    Trim certain operators to reduce unnecessary overhead.
+    """
+
+    redundant_ops: Set[OpType] = {
+        torch.clone,
+        torch.ops.aten.clone.default,
+        exir_ops.edge.aten.clone.default,
+        torch.ops.aten.alias.default,
+        exir_ops.edge.aten.alias.default,
+        exir_ops.edge.aten.lift_fresh_copy.default,
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+    }
+
+    def __init__(self) -> None:
+        super(RemoveRedundantOpsTransform, self).__init__()
+
+    def _should_remove(self, node: torch.fx.Node) -> bool:
+        if node.target in self.redundant_ops:
+            return True
+
+        # Only remove to_copy if the dtype does not change; memory format changes
+        # will be handled internally by the backend.
+        if (
+            node.target == exir_ops.edge.aten._to_copy.default
+            or node.target == torch.ops.aten._to_copy.default
+        ):
+            src_dtype = node.meta["val"].dtype
+            # pyre-ignore
+            dst_dtype = node.args[0].meta["val"].dtype
+            return src_dtype == dst_dtype
+
+        return False
+
+    def _remove(self, graph_module: torch.fx.GraphModule) -> None:
+        for node in graph_module.graph.nodes:
+            if not self._should_remove(node):
+                continue
+
+            with graph_module.graph.inserting_after(node):
+                node.replace_all_uses_with(node.args[0])
+
+        graph_module.graph.eliminate_dead_code()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        self._remove(graph_module)
+        graph_module.recompile()
+        dead_code_elimination_pass(graph_module)
+        return PassResult(graph_module, True)
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index eeec5ab37e6..eb831e352c2 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -228,6 +228,8 @@ def update_features_impl(op: OpKey):
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
         exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
         exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        # dim order copy operator will be removed; memory layout is handled internally
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     ]
 )
 def register_ephemeral_op(features: OpFeatures):
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index c938f9ff424..6e406a10ba6 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -17,11 +17,11 @@
 from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
 from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
-from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 
 from executorch.backends.vulkan._passes import (
     insert_prepack_nodes,
     RemoveLocalScalarDenseOpsTransform,
+    RemoveRedundantOpsTransform,
     TagMemoryMetaPass,
 )
 
@@ -143,7 +143,7 @@ def preprocess(  # noqa: C901
     program = apply_passes(
         program,
         [
-            RemoveCloneOpsTransform(),
+            RemoveRedundantOpsTransform(),
             AddmmToLinearTransform(),
             FuseDequantLinearPass(),
             FuseViewCopyTransform(),
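
---

Reviewer note (not part of the patch): below is a minimal sketch of exercising
RemoveRedundantOpsTransform in isolation on an exported graph. The toy
CloneHeavyModule, the sample input, and the direct call() invocation are
illustrative assumptions, not anything this patch ships; in production the
pass runs inside vulkan_preprocess.py as shown above.

import torch

from executorch.backends.vulkan._passes import RemoveRedundantOpsTransform


class CloneHeavyModule(torch.nn.Module):
    # Hypothetical toy module: its exported graph contains
    # aten.clone.default nodes, which the pass treats as redundant.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.clone().clone() + x


ep = torch.export.export(CloneHeavyModule(), (torch.randn(4),))
gm = ep.graph_module

before = [n.target for n in gm.graph.nodes if n.op == "call_function"]

# Invoke the pass through the call() override added in this patch.
result = RemoveRedundantOpsTransform().call(gm)

after = [n.target for n in result.graph_module.graph.nodes if n.op == "call_function"]
print("before:", before)  # expect aten.clone.default nodes to appear
print("after: ", after)   # expect the clone nodes to be gone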