diff --git a/backends/transforms/fuse_conv_with_clamp.py b/backends/transforms/fuse_conv_with_clamp.py
index 15973cae577..3903fe1bdf4 100644
--- a/backends/transforms/fuse_conv_with_clamp.py
+++ b/backends/transforms/fuse_conv_with_clamp.py
@@ -65,7 +65,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             with graph_module.graph.inserting_before(preceding_op):
                 conv_activation_node = graph_module.graph.create_node(
                     "call_function",
-                    torch.ops.et_vk.conv_with_clamp.default,
+                    exir_ops.edge.et_vk.conv_with_clamp.default,
                     new_args,
                 )
 
diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS
index 812c39c2b64..fa828640bf4 100644
--- a/backends/vulkan/_passes/TARGETS
+++ b/backends/vulkan/_passes/TARGETS
@@ -28,6 +28,18 @@ python_unittest(
     ],
 )
 
+runtime.python_library(
+    name = "insert_prepack_nodes",
+    srcs = ["insert_prepack_nodes.py"],
+    visibility = [
+        "//executorch/backends/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/exir:pass_base",
+    ],
+)
+
 runtime.python_library(
     name = "remove_local_scalar_dense",
     srcs = ["remove_local_scalar_dense_ops.py"],
@@ -65,6 +77,7 @@ runtime.python_library(
         "//executorch/examples/...",
     ],
     deps = [
+        ":insert_prepack_nodes",
         ":int4_weight_only_quantizer",
         ":remove_local_scalar_dense",
     ]
diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py
index 080df836080..cfdb7c6eeee 100644
--- a/backends/vulkan/_passes/__init__.py
+++ b/backends/vulkan/_passes/__init__.py
@@ -1,3 +1,4 @@
+from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
 from executorch.backends.vulkan._passes.int4_weight_only_quantizer import (
     VkInt4WeightOnlyQuantizer,
 )
@@ -6,6 +7,7 @@
 )
 
 __all__ = [
+    "insert_prepack_nodes",
     "VkInt4WeightOnlyQuantizer",
     "RemoveLocalScalarDenseOpsTransform",
 ]
diff --git a/backends/vulkan/_passes/custom_ops_defs.py b/backends/vulkan/_passes/custom_ops_defs.py
index 2c16a331c04..4da2a31fc44 100644
--- a/backends/vulkan/_passes/custom_ops_defs.py
+++ b/backends/vulkan/_passes/custom_ops_defs.py
@@ -9,6 +9,20 @@
 namespace = "et_vk"
 lib = torch.library.Library(namespace, "DEF")
 
+#############
+## prepack ##
+#############
+
+
+def prepack_impl(x: torch.Tensor):
+    return x
+
+
+name = "prepack"
+lib.define(f"{name}(Tensor x) -> Tensor")
+lib.impl(name, prepack_impl, "CompositeExplicitAutograd")
+prepack_op = getattr(getattr(torch.ops, namespace), name)
+
 #####################
 ## conv_with_clamp ##
 #####################
diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py
new file mode 100644
index 00000000000..afedf7af694
--- /dev/null
+++ b/backends/vulkan/_passes/insert_prepack_nodes.py
@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from typing import List
+
+import executorch.backends.vulkan._passes.custom_ops_defs  # noqa
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+
+from torch._export.utils import is_buffer, is_param
+from torch.export import ExportedProgram
+
+USES_WEIGHTS: List[torch._ops.OpOverload] = [
+    exir_ops.edge.aten.embedding.default,
+    exir_ops.edge.aten.convolution.default,
+    exir_ops.edge.et_vk.conv_with_clamp.default,
+    exir_ops.edge.aten.linear.default,
+    exir_ops.edge.aten._weight_int8pack_mm.default,
+    exir_ops.edge.et_vk.linear_weight_int4.default,
+    exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+    exir_ops.edge.aten.native_layer_norm.default,
+    "llama::sdpa_with_kv_cache",
+]
+
+
+def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
+    """
+    Insert `et_vk.prepack` nodes for constant tensors in the graph. The prepack operator
+    is responsible for transferring the tensor data, which is serialized with the model,
+    to a GPU tensor object during the prepacking stage of model execution.
+
+    Some operators, listed in `USES_WEIGHTS` above, are performance sensitive and will
+    prefer to handle prepacking within the operator. For these ops, the constant tensor
+    data will be passed directly as an argument into the operator implementation.
+    """
+
+    def is_get_attr_node(node: torch.fx.Node) -> bool:
+        return isinstance(node, torch.fx.Node) and node.op == "get_attr"
+
+    def is_constant(node: torch.fx.Node) -> bool:
+        return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
+
+    def is_param_node(node: torch.fx.Node) -> bool:
+        """
+        Check if the given node is a parameter within the exported program
+        """
+        return (
+            is_get_attr_node(node)
+            or is_param(program, node)
+            or is_buffer(program, node)
+            or is_constant(node)
+        )
+
+    def is_non_weight_param_tensor(node: torch.fx.Node) -> bool:
+        if not is_param_node(node):
+            return False
+
+        for user in node.users:
+            if user.op == "call_function" and (
+                # pyre-ignore [16]
+                user.target in USES_WEIGHTS
+                or user.target.name() in USES_WEIGHTS
+            ):
+                return False
+
+        return True
+
+    for node in program.graph_module.graph.nodes:
+        if not is_non_weight_param_tensor(node):
+            continue
+
+        with program.graph_module.graph.inserting_after(node):
+            prepack_node = program.graph_module.graph.create_node(
+                "call_function",
+                exir_ops.edge.et_vk.prepack.default,
+                (node,),
+            )
+            prepack_node.meta["spec"] = node.meta["spec"]
+            # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
+            # memory object. This pass must be executed AFTER the memory planning pass.
+            prepack_node.meta["spec"].mem_obj_id = -1
+            node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
+
+    program.graph.eliminate_dead_code()
+    return program
diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
index ac7b223eff8..15045ccca27 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp
@@ -6,6 +6,8 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
 #include
 
 #include
 
@@ -205,4 +207,12 @@ ValueRef prepack_direct_copy_buffer(
   return tensor;
 }
 
+void prepack_op(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  return add_prepack_standard_node(graph, args[0], args[1]);
+}
+
+REGISTER_OPERATORS {
+  VK_REGISTER_OP(et_vk.prepack.default, prepack_op);
+}
+
 } // namespace vkcompute
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 7ccfa89e8e7..f9820f825e1 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -251,11 +251,11 @@ def __init__(self):
                 self.weight = torch.rand(size=(2, 3), dtype=torch.float32)
 
             def forward(self, x, y):
-                z = torch.add(x, y, alpha=2)
-                z = torch.add(x, y, alpha=3.14)
-                z = z + x
-                z = z + self.weight
-                return z
+                inter1 = torch.add(x, y, alpha=2)
+                inter2 = torch.add(x, y, alpha=3.14)
+                inter3 = inter1 * self.weight
+                inter4 = inter2 * self.weight
+                return inter4 - inter3
 
         internal_data_module = InternalDataModule()
         sample_inputs = (
diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp
index 694eeebecee..87cafd10a7e 100644
--- a/backends/vulkan/test/vulkan_compute_api_test.cpp
+++ b/backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -1520,11 +1520,18 @@ TEST(VulkanComputeGraphTest, test_simple_prepacked_graph) {
   ValueRef c = graph.add_tensor(size_big, vkapi::kFloat);
   ValueRef e = graph.add_tensor(size_big, vkapi::kFloat);
 
+  ValueRef w1_packed = graph.add_tensor(size_small, vkapi::kFloat);
+  ValueRef w2_packed = graph.add_tensor(size_small, vkapi::kFloat);
+
+  auto prepackFn = VK_GET_OP_FN("et_vk.prepack.default");
+  prepackFn(graph, {w1, w1_packed});
+  prepackFn(graph, {w2, w2_packed});
+
   auto addFn = VK_GET_OP_FN("aten.add.Tensor");
-  addFn(graph, {a.value, w1, kDummyValueRef, c});
+  addFn(graph, {a.value, w1_packed, kDummyValueRef, c});
 
   auto mulFn = VK_GET_OP_FN("aten.mul.Tensor");
-  mulFn(graph, {c, w2, e});
+  mulFn(graph, {c, w2_packed, e});
 
   IOValueRef out = {};
   out.value = e;
@@ -2597,24 +2604,16 @@ void test_binary_op(
     std::vector<int64_t> sizes_big,
     std::vector<int64_t> sizes_small,
     vkapi::ScalarType dtype,
-    utils::GPUMemoryLayout memory_layout,
-    bool prepack = true) {
+    utils::GPUMemoryLayout memory_layout) {
   GraphConfig config;
   ComputeGraph graph(config);
 
   IOValueRef arg2{};
 
-  CREATE_WEIGHT_TENSOR(arg2_w, sizes_small, dtype, 2.5f);
-
   // Build graph
   IOValueRef arg1 = graph.add_input_tensor(sizes_big, dtype, memory_layout);
-
-  if (prepack) {
-    arg2.value = arg2_w;
-  } else {
-    arg2 = graph.add_input_tensor(sizes_small, dtype, memory_layout);
-  }
+  arg2 = graph.add_input_tensor(sizes_small, dtype, memory_layout);
 
   IOValueRef out;
   out.value = graph.add_tensor(sizes_big, dtype, memory_layout);
@@ -2635,7 +2634,7 @@ void test_binary_op(
 
   for (int i = 1; i < 4; i++) {
     float val_arg1 = i + 1.5;
-    float val_arg2 = prepack ? 2.5f : i - 3.5;
+    float val_arg2 = i - 3.5;
 
     float val_out = val_arg1 + val_arg2;
     if (op_name == "sub") {
@@ -2648,21 +2647,14 @@ void test_binary_op(
       val_out = val_arg1 / val_arg2;
     }
 
-    if (prepack) {
-      execute_graph_and_check_output(graph, {val_arg1}, {val_out});
-    } else {
-      execute_graph_and_check_output(graph, {val_arg1, val_arg2}, {val_out});
-    }
+    execute_graph_and_check_output(graph, {val_arg1, val_arg2}, {val_out});
   }
 }
 
-#define CALL_TEST_FN_FORALL_CONDITIONS(_)                            \
-  _(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false)    \
-  _(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, false)   \
-  _(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, false) \
-  _(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, true)     \
-  _(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, true)    \
-  _(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, true)
+#define CALL_TEST_FN_FORALL_CONDITIONS(_)                     \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked)    \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked)   \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked)
 
 #define CALL_TEST_FN_FOR_W_PACKED(_)                               \
   _(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false)  \
   _(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, true)   \
   _(vkapi::kFloat, utils::kBuffer, utils::kWidthPacked, false)     \
   _(vkapi::kFloat, utils::kBuffer, utils::kWidthPacked, true)
 
@@ -2677,15 +2669,15 @@ void test_binary_op(
   _(vkapi::kFloat, utils::kBuffer, utils::kChannelsPacked, true)
 
 TEST(VulkanComputeGraphOpsTest, add_smoke_test) {
-#define RUN_TESTS(dtype, storage, layout, prepack)                         \
-  test_binary_op("add", {17, 21}, {17, 21}, dtype, layout, prepack);       \
-  test_binary_op("add", {17, 21}, {1, 1}, dtype, layout, prepack);         \
-  test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout, prepack);       \
-  test_binary_op("sub", {11, 22}, {11, 1}, dtype, layout, prepack);        \
-  test_binary_op("add", {7, 17, 17}, {7, 17, 17}, dtype, layout, prepack); \
-  test_binary_op("add", {7, 17, 17}, {7, 1, 17}, dtype, layout, prepack);  \
-  test_binary_op("sub", {9, 9, 7}, {9, 9, 7}, dtype, layout, prepack);     \
-  test_binary_op("sub", {9, 9, 7}, {9, 1, 1}, dtype, layout, prepack);
+#define RUN_TESTS(dtype, storage, layout)                         \
+  test_binary_op("add", {17, 21}, {17, 21}, dtype, layout);       \
+  test_binary_op("add", {17, 21}, {1, 1}, dtype, layout);         \
+  test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout);       \
+  test_binary_op("sub", {11, 22}, {11, 1}, dtype, layout);        \
+  test_binary_op("add", {7, 17, 17}, {7, 17, 17}, dtype, layout); \
+  test_binary_op("add", {7, 17, 17}, {7, 1, 17}, dtype, layout);  \
+  test_binary_op("sub", {9, 9, 7}, {9, 9, 7}, dtype, layout);     \
+  test_binary_op("sub", {9, 9, 7}, {9, 1, 1}, dtype, layout);
 
   CALL_TEST_FN_FORALL_CONDITIONS(RUN_TESTS);
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index ed566a30ccc..777a56e3644 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -19,6 +19,7 @@
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 
 from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
+from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
 
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -86,6 +87,8 @@ def preprocess(  # noqa: C901
 
         _copy_module(program.graph_module, new_gm)
 
+        program = insert_prepack_nodes(program)
+
         graph_builder = VkGraphBuilder(
             program, DelegateMappingBuilder(generated_identifiers=True)
         )
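
Note for reviewers: at the PyTorch level the new `et_vk.prepack` op is a no-op, since `prepack_impl` in `custom_ops_defs.py` simply returns its input; the actual transfer into a GPU tensor object happens at runtime through the `et_vk.prepack.default` implementation registered in `Staging.cpp`. Below is a minimal sketch of that ahead-of-time behavior; the import path and op name come from this diff, while the sample tensor and assertion are illustrative only.

```python
import torch

# Importing this module registers the et_vk library ops, including et_vk.prepack
# (see custom_ops_defs.py in this diff).
import executorch.backends.vulkan._passes.custom_ops_defs  # noqa

x = torch.rand(2, 3)

# The eager/AOT implementation (prepack_impl) is the identity; the op only
# becomes meaningful once the graph is lowered to the Vulkan delegate, where
# it maps to add_prepack_standard_node.
y = torch.ops.et_vk.prepack(x)
assert torch.equal(x, y)
```

Because the op is numerically transparent, `insert_prepack_nodes` can run late in `vulkan_preprocess.py` without changing graph semantics: it only marks which constant tensors should be staged to GPU memory by the delegate, rather than passed directly into one of the weight-consuming ops listed in `USES_WEIGHTS`.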