From ff45acee32f648eafb10845f9833e6f96668e58d Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 18 Oct 2024 08:41:36 -0700 Subject: [PATCH] [ET-VK] Implement prepack nodes ## Context This diff implements the idea described in the previous diff in this stack. During export, `et_vk.prepack` nodes will be inserted to convert constant tensors to GPU tensor objects. This makes it so that Vulkan operators will not have to account for the possibility that their arguments can potentially be constant tensor data instead of an actual tensor object. Differential Revision: [D64603666](https://our.internmc.facebook.com/intern/diff/D64603666/) [ghstack-poisoned] --- backends/vulkan/_passes/TARGETS | 13 +++ backends/vulkan/_passes/__init__.py | 2 + backends/vulkan/_passes/custom_ops_defs.py | 14 +++ .../vulkan/_passes/insert_prepack_nodes.py | 92 +++++++++++++++++++ .../runtime/graph/ops/impl/BinaryOp.cpp | 10 +- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 10 ++ backends/vulkan/test/test_vulkan_delegate.py | 10 +- backends/vulkan/vulkan_preprocess.py | 3 + 8 files changed, 144 insertions(+), 10 deletions(-) create mode 100644 backends/vulkan/_passes/insert_prepack_nodes.py diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 812c39c2b64..fa828640bf4 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -28,6 +28,18 @@ python_unittest( ], ) +runtime.python_library( + name = "insert_prepack_nodes", + srcs = ["insert_prepack_nodes.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + "//caffe2:torch", + "//executorch/exir:pass_base", + ], +) + runtime.python_library( name = "remove_local_scalar_dense", srcs = ["remove_local_scalar_dense_ops.py"], @@ -65,6 +77,7 @@ runtime.python_library( "//executorch/examples/...", ], deps = [ + ":insert_prepack_nodes", ":int4_weight_only_quantizer", ":remove_local_scalar_dense", ] diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index 080df836080..cfdb7c6eeee 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -1,3 +1,4 @@ +from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes from executorch.backends.vulkan._passes.int4_weight_only_quantizer import ( VkInt4WeightOnlyQuantizer, ) @@ -6,6 +7,7 @@ ) __all__ = [ + "insert_prepack_nodes", "VkInt4WeightOnlyQuantizer", "RemoveLocalScalarDenseOpsTransform", ] diff --git a/backends/vulkan/_passes/custom_ops_defs.py b/backends/vulkan/_passes/custom_ops_defs.py index 2c16a331c04..4da2a31fc44 100644 --- a/backends/vulkan/_passes/custom_ops_defs.py +++ b/backends/vulkan/_passes/custom_ops_defs.py @@ -9,6 +9,20 @@ namespace = "et_vk" lib = torch.library.Library(namespace, "DEF") +############# +## prepack ## +############# + + +def prepack_impl(x: torch.Tensor): + return x + + +name = "prepack" +lib.define(f"{name}(Tensor x) -> Tensor") +lib.impl(name, prepack_impl, "CompositeExplicitAutograd") +prepack_op = getattr(getattr(torch.ops, namespace), name) + ##################### ## conv_with_clamp ## ##################### diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py new file mode 100644 index 00000000000..5a49c587cbd --- /dev/null +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import List + +import executorch.backends.vulkan._passes.custom_ops_defs # noqa + +import torch + +from executorch.exir.dialects._ops import ops as exir_ops + +from torch._export.utils import is_buffer, is_param +from torch.export import ExportedProgram + +USES_WEIGHTS: List[torch._ops.OpOverload] = [ + exir_ops.edge.aten.embedding.default, + exir_ops.edge.aten.convolution.default, + exir_ops.edge.et_vk.conv_with_clamp.default, + exir_ops.edge.aten.linear.default, + exir_ops.edge.aten._weight_int8pack_mm.default, + exir_ops.edge.et_vk.linear_weight_int4.default, + exir_ops.edge.aten._native_batch_norm_legit_no_training.default, + exir_ops.edge.aten.native_layer_norm.default, + "llama::sdpa_with_kv_cache", +] + + +def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: + """ + Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator + is responsible for transferring the tensor data, which is serialized with the model, + to a GPU tensor object during the prepacking stage of model execution. + + Some operators, listed in `USES_WEIGHTS` above, are performance sensitive and will + prefer to handle prepacking within the operator. For these ops, the constant tensor + data will be passed directly as an argument into the operator implementation. + """ + + def is_get_attr_node(node: torch.fx.Node) -> bool: + return isinstance(node, torch.fx.Node) and node.op == "get_attr" + + def is_constant(node: torch.fx.Node) -> bool: + return node.name in program.graph_signature.inputs_to_lifted_tensor_constants + + def is_param_node(node: torch.fx.Node) -> bool: + """ + Check if the given node is a parameter within the exported program + """ + return ( + is_get_attr_node(node) + or is_param(program, node) + or is_buffer(program, node) + or is_constant(node) + ) + + def is_non_weight_param_tensor(node: torch.fx.Node) -> bool: + if not is_param_node(node): + return False + + for user in node.users: + if user.op == "call_function" and ( + # pyre-ignore [16] + user.target in USES_WEIGHTS + or user.target.name() in USES_WEIGHTS + ): + return False + + return True + + for node in program.graph_module.graph.nodes: + if not is_non_weight_param_tensor(node): + continue + + with program.graph_module.graph.inserting_after(node): + prepack_node = program.graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.prepack.default, + (node,), + ) + prepack_node.meta["spec"] = node.meta["spec"] + # Set the mem_obj_id to -1 to indicate that this node requires a dedicated + # memory object. This pass must be executed AFTER the memory planning pass. + prepack_node.meta["spec"].mem_obj_id = -1 + node.replace_all_uses_with(prepack_node, lambda x: x != prepack_node) + + program.graph.eliminate_dead_code() + return program diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index c055431a84b..d2e43e59880 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -51,11 +51,11 @@ void add_binary_op_node( const ValueRef alpha, const ValueRef out, const std::string& op_name) { - ValueRef arg1 = prepack_standard_like(graph, in1, out, true); - ValueRef arg2 = prepack_standard_like(graph, in2, out, true); + VK_CHECK_COND(graph.val_is_tensor(in1)); + VK_CHECK_COND(graph.val_is_tensor(in2)); - vTensorPtr t_in1 = graph.get_tensor(arg1); - vTensorPtr t_in2 = graph.get_tensor(arg2); + vTensorPtr t_in1 = graph.get_tensor(in1); + vTensorPtr t_in2 = graph.get_tensor(in2); vTensorPtr t_out = graph.get_tensor(out); check_binary_op_args(*t_in1, *t_in2, *t_out); @@ -81,7 +81,7 @@ void add_binary_op_node( graph.create_local_wg_size(out), // Inputs and Outputs {{out, vkapi::MemoryAccessType::WRITE}, - {{arg1, arg2}, vkapi::MemoryAccessType::READ}}, + {{in1, in2}, vkapi::MemoryAccessType::READ}}, // Shader params buffers {t_out->sizes_ubo(), t_out->axis_map_ubo(), diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 50fd63445d9..4231dc8bc2d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -6,6 +6,8 @@ * LICENSE file in the root directory of this source tree. */ +#include + #include #include @@ -204,4 +206,12 @@ ValueRef prepack_direct_copy_buffer( return tensor; } +void prepack_op(ComputeGraph& graph, const std::vector& args) { + return add_standard_prepack_node(graph, args[0], args[1]); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(et_vk.prepack.default, prepack_op); +} + } // namespace vkcompute diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 7ccfa89e8e7..f9820f825e1 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -251,11 +251,11 @@ def __init__(self): self.weight = torch.rand(size=(2, 3), dtype=torch.float32) def forward(self, x, y): - z = torch.add(x, y, alpha=2) - z = torch.add(x, y, alpha=3.14) - z = z + x - z = z + self.weight - return z + inter1 = torch.add(x, y, alpha=2) + inter2 = torch.add(x, y, alpha=3.14) + inter3 = inter1 * self.weight + inter4 = inter2 * self.weight + return inter4 - inter3 internal_data_module = InternalDataModule() sample_inputs = ( diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index ed566a30ccc..777a56e3644 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -19,6 +19,7 @@ from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform +from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder from executorch.backends.vulkan.serialization.vulkan_graph_serialize import ( @@ -86,6 +87,8 @@ def preprocess( # noqa: C901 _copy_module(program.graph_module, new_gm) + program = insert_prepack_nodes(program) + graph_builder = VkGraphBuilder( program, DelegateMappingBuilder(generated_identifiers=True) )