def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
    """
    Return True if `node` is part of a scalar-tensor-to-scalar extraction chain.

    Scalar tensors are usually converted to scalar values in the graph via
    `scalar_tensor[0].item()` in Python, which translates to a chain of
    `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph.
    This function marks the entire chain as supported by the Vulkan delegate.

    Later, within vulkan_preprocess there will be a graph transform which
    replaces the chain with passing in the scalar tensor directly.
    """
    # The terminal op of the chain is always claimable; the preprocess
    # transform removes it outright.
    if node.target == torch.ops.aten._local_scalar_dense.default:
        return True

    if node.target == exir_ops.edge.aten.select_copy.int:
        # Only claim the select op when it feeds exactly one consumer ...
        if len(node.users) != 1:
            return False
        # ... and its input is a single-element (i.e. scalar) tensor.
        # pyre-ignore
        if node.args[0].meta["val"].numel() != 1:
            return False

        # The select is part of the chain only if its sole user is
        # local_scalar_dense.
        user = next(iter(node.users))
        return user.target == torch.ops.aten._local_scalar_dense.default

    return False
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


def remove_local_scalar_dense_ops(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Remove local_scalar_dense op nodes and replace uses with parent node, or the
    original scalar tensor.

    The disconnected nodes are cleaned up via dead-code elimination at the end,
    and the mutated graph is returned.
    """
    target_op = torch.ops.aten._local_scalar_dense.default
    for node in graph.nodes:
        if node.op == "call_function" and node.target == target_op:
            replace_node = node.args[0]
            # If the argument to the local_scalar_dense op is a select op whose
            # input is a tensor with only one element (i.e. a scalar tensor),
            # then replace the entire pattern with the scalar tensor itself.
            # NOTE(review): unlike the partitioner's check, this does not
            # verify the select has a single user — presumed guaranteed by the
            # partitioner only claiming single-user selects; confirm.
            if (
                replace_node.op == "call_function"
                and replace_node.target == exir_ops.edge.aten.select_copy.int
            ):
                if replace_node.args[0].meta["val"].numel() == 1:
                    replace_node = replace_node.args[0]

            # No new nodes are created here, so no `graph.inserting_after`
            # insertion context is needed — rewiring uses is sufficient.
            node.replace_all_uses_with(replace_node)

    # Drop the now-unused local_scalar_dense (and possibly select) nodes.
    graph.eliminate_dead_code()
    return graph


class RemoveLocalScalarDenseOpsTransform(ExportPass):
    """Graph pass wrapper around `remove_local_scalar_dense_ops`."""

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph_module.graph = remove_local_scalar_dense_ops(graph_module.graph)
        return PassResult(graph_module, True)
executorch.backends.vulkan.serialization.vulkan_graph_serialize import ( serialize_vulkan_graph, @@ -57,6 +61,7 @@ def preprocess( # noqa: C901 MeanToSumDiv(), SpecPropPass(), ConstraintBasedSymShapeEvalPass(), + RemoveLocalScalarDenseOpsTransform(), MemoryPlanningPass(), ]