From 69d6d713be09e90d2cba324525750a6003723d7a Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 13 Dec 2024 13:09:52 -0800 Subject: [PATCH] [ET-VK] Add pass to remove copy ops ## Context This diff prepares Vulkan to handle dim order operators. For more context, see https://github.com/pytorch/executorch/issues/4873 Since Vulkan has its own internal representation of memory layout, these ops are handled by simply remove explicit memory layout transition operators from the graph and let the memory metadata tagging pass insert the necessary memory layout transitions. A new pass is added to remove such operators, largely based on QNN's `RemoveRedundancy` pass. Differential Revision: [D67180898](https://our.internmc.facebook.com/intern/diff/D67180898/) [ghstack-poisoned] --- backends/vulkan/_passes/TARGETS | 14 ++++ backends/vulkan/_passes/__init__.py | 4 ++ .../vulkan/_passes/remove_redundant_ops.py | 69 +++++++++++++++++++ backends/vulkan/op_registry.py | 2 + backends/vulkan/vulkan_preprocess.py | 4 +- 5 files changed, 91 insertions(+), 2 deletions(-) create mode 100644 backends/vulkan/_passes/remove_redundant_ops.py diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index ed3d8479331..74048cfb6a7 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -43,6 +43,19 @@ runtime.python_library( ], ) +runtime.python_library( + name = "remove_redundant_ops", + srcs = ["remove_redundant_ops.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + ], +) + runtime.python_library( name = "tag_memory_meta_pass", srcs = ["tag_memory_meta_pass.py"], @@ -71,6 +84,7 @@ runtime.python_library( ":insert_prepack_nodes", ":int4_weight_only_quantizer", ":remove_local_scalar_dense", + ":remove_redundant_ops", ":tag_memory_meta_pass" ] ) diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index 8823553ab13..416339574ba 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -5,11 +5,15 @@ from executorch.backends.vulkan._passes.remove_local_scalar_dense_ops import ( RemoveLocalScalarDenseOpsTransform, ) +from executorch.backends.vulkan._passes.remove_redundant_ops import ( + RemoveRedundantOpsTransform, +) from executorch.backends.vulkan._passes.tag_memory_meta_pass import TagMemoryMetaPass __all__ = [ "insert_prepack_nodes", "VkInt4WeightOnlyQuantizer", "RemoveLocalScalarDenseOpsTransform", + "RemoveRedundantOpsTransform", "TagMemoryMetaPass", ] diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py new file mode 100644 index 00000000000..530505f7003 --- /dev/null +++ b/backends/vulkan/_passes/remove_redundant_ops.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Set, Union + +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.dialects.edge._ops import EdgeOpOverload +from executorch.exir.pass_base import ExportPass, PassResult +from executorch.exir.passes import dead_code_elimination_pass + +OpType = Union[str, torch._ops.OpOverload, EdgeOpOverload] + + +class RemoveRedundantOpsTransform(ExportPass): + """ + Trim certain operators to reduce unnecessary overhead. + """ + + redundant_ops: Set[OpType] = { + torch.clone, + torch.ops.aten.clone.default, + exir_ops.edge.aten.clone.default, + torch.ops.aten.alias.default, + exir_ops.edge.aten.alias.default, + exir_ops.edge.aten.lift_fresh_copy.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + } + + def __init__(self) -> None: + super(RemoveRedundantOpsTransform, self).__init__() + + def _should_remove(self, node: torch.fx.Node) -> bool: + if node.target in self.redundant_ops: + return True + + # Only remove to_copy if dtype does not change. Otherwise, memory format changes + # will be handled internally by the backend. + if ( + node.target == exir_ops.edge.aten._to_copy.default + or node.target == torch.ops.aten._to_copy.default + ): + src_dtype = node.meta["val"].dtype + # pyre-ignore + dst_dtype = node.args[0].meta["val"].dtype + return src_dtype == dst_dtype + + return False + + def _remove(self, graph_module: torch.fx.GraphModule) -> None: + for node in graph_module.graph.nodes: + if not self._should_remove(node): + continue + + with graph_module.graph.inserting_after(node): + node.replace_all_uses_with(node.args[0]) + + graph_module.graph.eliminate_dead_code() + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + self._remove(graph_module) + graph_module.recompile() + dead_code_elimination_pass(graph_module) + return PassResult(graph_module, True) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index eeec5ab37e6..eb831e352c2 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -228,6 +228,8 @@ def update_features_impl(op: OpKey): exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor, exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + # dim order copy operator will be removed; memory layout is handled internally + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, ] ) def register_ephemeral_op(features: OpFeatures): diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index c938f9ff424..6e406a10ba6 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -17,11 +17,11 @@ from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform -from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.backends.vulkan._passes import ( insert_prepack_nodes, RemoveLocalScalarDenseOpsTransform, + RemoveRedundantOpsTransform, TagMemoryMetaPass, ) @@ -143,7 +143,7 @@ def preprocess( # noqa: C901 program = apply_passes( program, [ - RemoveCloneOpsTransform(), + RemoveRedundantOpsTransform(), AddmmToLinearTransform(), FuseDequantLinearPass(), FuseViewCopyTransform(),