From bc935a8721f089801cf2ad100f2312065f65f2ce Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Tue, 3 Jun 2025 21:19:17 -0700 Subject: [PATCH] [ET-VK] Migrate ops to use `DynamicDispatchNode` Pull Request resolved: https://github.com/pytorch/executorch/pull/11312 ## Changes * Migrate operators that are used in the llama model to use `DynamicDispatchNode` instead of `DispatchNode` ## Motivation `DynamicDispatchNode` is a subclass of `DispatchNode` that allows dynamic selection of compute shaders, global and local work group sizing whenever the command buffer is encoded. This is critical for ensuring optimum performance when input shapes are dynamic, since it allows operators to select the best compute shader for the input conditions and also to adjust global work group sizing to launch the minimum number of work groups necessary. Without this change, performance of llama 3.2 1B with dynamic shapes enabled is terrible (< 1 tok/s) because global work group sizing is determined based on maximum tensor sizes, which is based on the maximum sequence length. In practice, the sequence length dimension of tensors (even during the prefill phase) will not approach the maximum. This results in a lot of inactive threads launched during compute shader dispatches. ghstack-source-id: 288057588 Differential Revision: [D75878398](https://our.internmc.facebook.com/intern/diff/D75878398/) --- .../vulkan/runtime/graph/ComputeGraph.cpp | 9 + backends/vulkan/runtime/graph/ComputeGraph.h | 7 + .../runtime/graph/ops/DynamicDispatchNode.cpp | 5 +- .../runtime/graph/ops/impl/BinaryOp.cpp | 17 +- .../vulkan/runtime/graph/ops/impl/Clone.cpp | 36 ++-- .../vulkan/runtime/graph/ops/impl/Common.cpp | 7 +- .../vulkan/runtime/graph/ops/impl/Common.h | 12 +- .../vulkan/runtime/graph/ops/impl/MatMul.cpp | 162 +++++++++++------- .../graph/ops/impl/QuantizedLinearQGANW.cpp | 112 ++++++++---- .../vulkan/runtime/graph/ops/impl/Reduce.cpp | 101 +++++++---- .../graph/ops/impl/RotaryEmbedding.cpp | 30 +++- .../vulkan/runtime/graph/ops/impl/SDPA.cpp | 70 +++++--- .../vulkan/runtime/graph/ops/impl/Softmax.cpp | 96 ++++++++--- .../vulkan/runtime/graph/ops/impl/Staging.cpp | 65 +++++-- .../vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 10 +- .../vulkan/runtime/graph/ops/impl/View.cpp | 7 +- backends/vulkan/runtime/utils/VecUtils.h | 10 +- backends/vulkan/test/utils/test_utils.cpp | 9 +- backends/vulkan/test/utils/test_utils.h | 4 +- .../vulkan/test/vulkan_compute_api_test.cpp | 16 +- 20 files changed, 521 insertions(+), 264 deletions(-) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp index 1222a9fc641..68935e63123 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.cpp +++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp @@ -449,6 +449,15 @@ ValueRef ComputeGraph::add_symint(const int32_t val) { return idx; } +ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) { + for (int i = 0; i < values_.size(); ++i) { + if (values_.at(i).isInt() && values_.at(i).toInt() == val) { + return i; + } + } + return add_scalar(val); +} + ValueRef ComputeGraph::set_input_tensor( const ValueRef idx, const bool use_staging) { diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index fe546f26477..0e8a5eba51f 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -604,6 +604,13 @@ class ComputeGraph final { ValueRef add_symint(const int32_t val); + /* + * Searches the graph's value list for a Int value with the specified value. + * If one is found, returns the index of the value. Otherwise, add a new value + * and return the index of the new value. + */ + ValueRef get_or_add_value_for_int(const int64_t val); + ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp index a8d2fe2e99d..b8c0fcbbf79 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp @@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode( const ResizeFunction& resize_fn) : DispatchNode( graph, - vkapi::ShaderInfo(), - {1u, 1u, 1u}, + pick_shader_fn(&graph, args, resize_args), {1u, 1u, 1u}, + {8u, 8u, 1u}, args, params, push_constants, @@ -37,7 +37,6 @@ DynamicDispatchNode::DynamicDispatchNode( pick_shader_fn_(pick_shader_fn), pick_global_wg_fn_(pick_global_wg_fn), pick_local_wg_fn_(pick_local_wg_fn) { - shader_ = pick_shader_fn(&graph, args, resize_args); global_workgroup_size_ = pick_global_wg_fn(&graph, shader_, args, resize_args); local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn( diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index ff6b54c5289..d260ed767d0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -30,8 +31,8 @@ void check_binary_op_args( void resize_binary_op_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; + const std::vector& resize_args) { + (void)resize_args; vTensorPtr out = graph->get_tensor(args[0].refs[0]); // TODO(T183442143): Verify tensors are broadcastable. @@ -78,11 +79,11 @@ void add_binary_op_texture_node( add_storage_type_suffix(kernel_name, *t_out); add_dtype_suffix(kernel_name, *t_out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{arg1, arg2}, vkapi::kRead}}, // Shader params buffers @@ -122,11 +123,11 @@ void add_binary_op_buffer_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{in1, in2}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp index d0276b1783b..da06223cd12 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp @@ -10,6 +10,7 @@ #include +#include #include #include @@ -21,8 +22,8 @@ namespace vkcompute { void resize_clone_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; + const std::vector& resize_args) { + (void)resize_args; vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr in = graph->get_tensor(args[1].refs[0]); // TODO: support for when dimensionality doesn't match, i.e. clone is used to @@ -41,11 +42,11 @@ void add_clone_node( std::string kernel_name = "clone"; add_dtype_suffix(kernel_name, *t_out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Parameter Buffers @@ -60,6 +61,17 @@ void add_clone_node( resize_clone_node)); } +utils::uvec3 clone_image_to_buffer_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef image = args.at(1).refs.at(0); + return graph->create_global_wg_size(image); +} + void add_image_to_buffer_node( ComputeGraph& graph, const ValueRef image, @@ -68,12 +80,11 @@ void add_image_to_buffer_node( add_dtype_suffix(kernel_name, graph.dtype_of(image)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); - utils::uvec3 global_wg_size = graph.create_global_wg_size(image); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - global_wg_size, - graph.create_local_wg_size(global_wg_size), + clone_image_to_buffer_global_wg_size, + default_pick_local_wg_size, // Input and Outputs {{buffer, vkapi::kWrite}, {image, vkapi::kRead}}, // Parameter Buffers @@ -96,12 +107,11 @@ void add_buffer_to_image_node( add_dtype_suffix(kernel_name, graph.dtype_of(image)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); - utils::uvec3 global_wg_size = graph.create_global_wg_size(image); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - global_wg_size, - graph.create_local_wg_size(global_wg_size), + default_pick_global_wg_size, + default_pick_local_wg_size, // Input and Outputs {{image, vkapi::kWrite}, {buffer, vkapi::kRead}}, // Parameter Buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp index 4de099231d3..4c3c16417b5 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp @@ -14,8 +14,9 @@ utils::uvec3 default_pick_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const std::vector& args, - const std::vector& additional_args) { + const std::vector& resize_args) { (void)shader; + (void)resize_args; const ValueRef out = args.at(0).refs.at(0); return graph->create_global_wg_size(out); } @@ -25,8 +26,10 @@ utils::uvec3 default_pick_local_wg_size( const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const std::vector& args, - const std::vector& additional_args) { + const std::vector& resize_args) { (void)shader; + (void)args; + (void)resize_args; return graph->create_local_wg_size(global_workgroup_size); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.h b/backends/vulkan/runtime/graph/ops/impl/Common.h index d5ff455ae41..662fb07095a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.h +++ b/backends/vulkan/runtime/graph/ops/impl/Common.h @@ -17,31 +17,23 @@ namespace vkcompute { * Creates a global workgroup size based on the first output tensor in the args. * This is a utility function that extracts the output tensor from * args.at(0).refs.at(0) and calls graph->create_global_wg_size(out) on it. - * - * @param graph The ComputeGraph instance - * @param args Vector of ArgGroup containing the output tensor reference - * @return utils::uvec3 The global workgroup size */ utils::uvec3 default_pick_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const std::vector& args, - const std::vector& additional_args); + const std::vector& resize_args); /** * Creates a local workgroup size based on the first output tensor in the args. * This is a utility function that extracts the output tensor from * args.at(0).refs.at(0) and calls graph->create_local_wg_size(out) on it. - * - * @param graph The ComputeGraph instance - * @param args Vector of ArgGroup containing the output tensor reference - * @return utils::uvec3 The local workgroup size */ utils::uvec3 default_pick_local_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, const utils::uvec3& global_workgroup_size, const std::vector& args, - const std::vector& additional_args); + const std::vector& resize_args); } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 724f4630264..73a625f3adf 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -37,12 +38,12 @@ void check_matmul_args( void resize_matmul_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { + const std::vector& resize_args) { vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); - bool mat2_is_transposed = graph->get_bool(extra_args[0]); + bool mat2_is_transposed = graph->get_bool(resize_args[0]); const int out_cols = utils::val_at(-2, mat1->sizes()); const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2->sizes()) @@ -56,6 +57,23 @@ void resize_matmul_node( out->virtual_resize(new_out_sizes); } +/** + * Custom global workgroup size function for naive buffer matmul operations. + */ +utils::uvec3 matmul_naive_buffer_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return { + graph->size_at(-1, out), + graph->size_at(-2, out), + graph->size_at(-3, out) * graph->size_at(-4, out)}; +} + void add_matmul_naive_buffer_node( ComputeGraph& graph, const ValueRef mat1, @@ -72,21 +90,16 @@ void add_matmul_naive_buffer_node( std::string kernel_name = "matmul_naive_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); - utils::uvec3 global_size = { - graph.size_at(-1, out), - graph.size_at(-2, out), - graph.size_at(-3, out) * graph.size_at(-4, out)}; - int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef && graph.get_bool(mat2_is_transposed)) ? 1 : 0; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - graph.create_local_wg_size(global_size), + matmul_naive_buffer_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, // Shader params buffers @@ -109,6 +122,22 @@ void add_matmul_naive_buffer_node( resize_matmul_node)); } +vkapi::ShaderInfo pick_matmul_naive_texture3d_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const bool is_transposed = graph->get_bool(resize_args.at(0)); + + std::string kernel_name = + is_transposed ? "matmul_transposed_naive" : "matmul_naive"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(kernel_name); +} + void add_matmul_naive_texture3d_node( ComputeGraph& graph, const ValueRef mat1, @@ -122,19 +151,11 @@ void add_matmul_naive_texture3d_node( utils::kHeightPacked, /*passthrough = */ true); - std::string kernel_name = graph.get_bool(mat2_is_transposed) - ? "matmul_transposed_naive" - : "matmul_naive"; - kernel_name.reserve(kShaderNameReserve); - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - utils::uvec3 global_wg_size = graph.logical_limits_of(out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - graph.create_local_wg_size(global_wg_size), + pick_matmul_naive_texture3d_shader, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, // Shader params buffers @@ -156,6 +177,59 @@ void add_matmul_naive_texture3d_node( resize_matmul_node)); } +vkapi::ShaderInfo pick_matmul_optimized_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1_W_packed = resize_args.at(1); + const bool mat2_is_transposed_val = graph->get_bool(resize_args.at(0)); + + std::string kernel_name = mat2_is_transposed_val + ? "matmul_transposed_optimized" + : "matmul_optimized"; + + std::vector mat1_sizes = graph->sizes_of(mat1_W_packed); + size_t mat1_dims = mat1_sizes.size(); + if (mat1_dims == 3) { + kernel_name = "batch_" + kernel_name; + } + if (mat1_sizes.at(mat1_dims - 2) < 8) { + kernel_name += "_tile_row_2"; + } else { + kernel_name += "_tile_row_4"; + } + + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 matmul_optimized_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1_W_packed = resize_args.at(1); + + const std::vector mat1_sizes = graph->sizes_of(mat1_W_packed); + const size_t mat1_dims = mat1_sizes.size(); + + utils::uvec3 global_size = graph->logical_limits_of(out); + if (mat1_sizes.at(mat1_dims - 2) < 8) { + // Use `logical_extents` instead of `image_extents` because the workgroup + // axes need to correspond to tensor dimensions. + global_size = utils::divup_vec(global_size, {4, 2, 1}); + } else { + global_size = utils::divup_vec(global_size, {4, 4, 1}); + } + + return global_size; +} + void add_matmul_optimized_node( ComputeGraph& graph, const ValueRef mat1, @@ -192,45 +266,11 @@ void add_matmul_optimized_node( viewFn(graph, {mat2, graph.add_none(), mat2_packed}); } - std::string kernel_name = mat2_is_transposed_val - ? "matmul_transposed_optimized" - : "matmul_optimized"; - - std::vector mat1_sizes = graph.sizes_of(mat1_W_packed); - int mat1_dims = mat1_sizes.size(); - if (mat1_dims == 3) { - kernel_name = "batch_" + kernel_name; - } - if (mat1_sizes.at(mat1_dims - 2) < 8) { - kernel_name += "_tile_row_2"; - } else { - kernel_name += "_tile_row_4"; - } - - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the - // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is - // channels packed, C does not need to be divided by 4. The "identity" of each - // thread is the (x, y, z) coordinate of the output tile it is computing, and - // this identity can be used to compute the tensor index of the top left - // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0] - utils::uvec3 global_size = graph.logical_limits_of(out); - if (mat1_sizes.at(mat1_dims - 2) < 8) { - // Use `logical_extents` instead of `image_extents` because the workgroup - // axes need to correspond to tensor dimensions. - global_size = utils::divup_vec(global_size, {4, 2, 1}); - } else { - global_size = utils::divup_vec(global_size, {4, 4, 1}); - } - - utils::uvec3 local_size = adaptive_work_group_size(global_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + pick_matmul_optimized_shader, + matmul_optimized_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}}, // Shader params buffers @@ -246,7 +286,7 @@ void add_matmul_optimized_node( graph.hashed_layout_of(mat1_W_packed), graph.hashed_layout_of(mat2_packed)}, // Resize Args - {mat2_is_transposed}, + {mat2_is_transposed, mat1_W_packed}, // Resizing Logic resize_matmul_node)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp index 8c5cb0093d9..d9425b8b62f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -70,30 +71,26 @@ void resize_linear_qga4w_node( out->virtual_resize(new_out_sizes); } -void add_linear_qga4w_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef group_size, - const ValueRef scales_and_zeros_data, - const ValueRef out) { - check_linear_qga4w_args( - graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); - - const uint32_t group_size_val = graph.extract_scalar(group_size); +/** + * Determines if the cooperative algorithm should be used based on input tensor + * dimensions. Apply the coop algorithm for gemv cases, i.e. mat1 is avector as + * as opposed to a matrix. + */ +bool should_use_coop_algorithm(ComputeGraph* graph, const ValueRef& mat1) { + return graph->size_at(-2, mat1) == 1; +} - bool use_coop_algorithm = false; - // Apply the coop algorithm for gemv cases, i.e. mat1 is a vector as opposed - // to a matrix. - if (graph.size_at(-2, mat1) == 1) { - use_coop_algorithm = true; - } +vkapi::ShaderInfo pick_linear_qga4w_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; - ValueRef mat2 = - prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1 = args.at(1).refs.at(0); + const ValueRef mat2 = args.at(1).refs.at(1); - ValueRef scales_and_zeros = prepack_standard_hw_transposed( - graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); + const bool use_coop_algorithm = should_use_coop_algorithm(graph, mat1); std::string kernel_name = "linear_qga4w"; if (use_coop_algorithm) { @@ -101,26 +98,75 @@ void add_linear_qga4w_node( } else { kernel_name += "_tiled"; } - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(mat2)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(mat1)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(mat2)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); - utils::uvec3 global_wg_size = graph.logical_limits_of(out); + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 linear_qga4w_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; + + utils::uvec3 global_wg_size = graph->logical_limits_of(out); global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); - utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); + + if (!use_coop_algorithm) { + global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); + } + + return global_wg_size; +} + +utils::uvec3 linear_qga4w_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + (void)resize_args; + const bool use_coop_algorithm = + shader.kernel_name.find("_coop") != std::string::npos; if (use_coop_algorithm) { - local_wg_size = {8, 1, 8}; + return {8, 1, 8}; } else { - global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); + return graph->create_local_wg_size(global_workgroup_size); } +} + +void add_linear_qga4w_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2_data, + const ValueRef group_size, + const ValueRef scales_and_zeros_data, + const ValueRef out) { + check_linear_qga4w_args( + graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); - graph.execute_nodes().emplace_back(new DispatchNode( + const uint32_t group_size_val = graph.extract_scalar(group_size); + + ValueRef mat2 = + prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); + + ValueRef scales_and_zeros = prepack_standard_hw_transposed( + graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + pick_linear_qga4w_shader, + linear_qga4w_global_wg_size, + linear_qga4w_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, scales_and_zeros}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp index 8fcd4a0609c..c0fd442ec50 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -20,23 +21,66 @@ using namespace utils; void resize_reduce_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; + const std::vector& resize_args) { vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr in = graph->get_tensor(args[1].refs[0]); - int dim = extra_args[0]; + int32_t reduce_dim_nchw = graph->extract_scalar(resize_args.at(0)); std::vector new_sizes = in->sizes(); - new_sizes[normalize(dim, new_sizes.size())] = 1; + new_sizes.at(normalize(reduce_dim_nchw, new_sizes.size())) = 1; out->virtual_resize(new_sizes); } +utils::uvec3 reduce_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + const ValueRef out = args.at(0).refs.at(0); + const int32_t reduce_dim_whcn = + graph->extract_scalar(resize_args.at(1)); + + utils::uvec3 global_wg_size = graph->logical_limits_of(out); + global_wg_size[reduce_dim_whcn] = 1; + return global_wg_size; +} + +utils::uvec3 reduce_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)args; + (void)global_workgroup_size; + + const int32_t reduce_dim_whcn = + graph->extract_scalar(resize_args.at(1)); + const int64_t group_dim_whcn = + graph->extract_scalar(resize_args.at(2)); + + // This should match the value of MAX_NTHREADS in the reduce shader. + constexpr uint32_t max_nthreads = 16; + + const uint32_t nworkers_per_group = 4; + const uint32_t ngroups = 4; + VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads); + + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim_whcn] = nworkers_per_group; + local_wg_size[group_dim_whcn] = ngroups; + + return local_wg_size; +} + void add_reduce_node( ComputeGraph& graph, - ValueRef in, - const int dim, - ValueRef out, + const ValueRef in, + const ValueRef dim_ref, + const ValueRef out, const std::string& op_name) { VK_CHECK_COND( !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out), @@ -44,7 +88,7 @@ void add_reduce_node( const int64_t ndim = graph.dim_of(in); - int32_t reduce_dim = dim; + int32_t reduce_dim = graph.extract_scalar(dim_ref); reduce_dim = normalize(reduce_dim, ndim); reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); @@ -55,40 +99,30 @@ void add_reduce_node( VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim); } - vkapi::ShaderInfo shader_descriptor; std::string kernel_name = op_name; kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - // This should match the value of MAX_NTHREADS in the softmax shader. - constexpr uint32_t max_nthreads = 16; - - const uint32_t nworkers_per_group = 4; - const uint32_t ngroups = 4; - VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads); - - utils::uvec3 global_wg_size = graph.logical_limits_of(out); - global_wg_size[reduce_dim] = 1; - - utils::uvec3 local_wg_size{1, 1, 1}; - local_wg_size[reduce_dim] = nworkers_per_group; + // Calculate group_dim for specialization constants const int other_dim_1 = (reduce_dim + 1) % 3; const int other_dim_2 = (reduce_dim + 2) % 3; int32_t group_dim; - if (global_wg_size[other_dim_1] > global_wg_size[other_dim_2]) { - local_wg_size[other_dim_1] = ngroups; + utils::uvec3 limits = graph.logical_limits_of(out); + if (limits[other_dim_1] > limits[other_dim_2]) { group_dim = other_dim_1; } else { - local_wg_size[other_dim_2] = ngroups; group_dim = other_dim_2; } - graph.execute_nodes().emplace_back(new DispatchNode( + const ValueRef reduce_dim_whcn_ref = + graph.get_or_add_value_for_int(reduce_dim); + const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int(group_dim); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // shader_descriptor, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + reduce_global_wg_size, + reduce_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers @@ -98,17 +132,20 @@ void add_reduce_node( // Specialization Constants {graph.packed_dim_of(out), reduce_dim, group_dim}, // Resize Args - {dim}, + {dim_ref, reduce_dim_whcn_ref, group_dim_whcn_ref}, // Resizing Logic resize_reduce_node)); } #define DEFINE_REDUCE_FN(op_name, out_arg_idx) \ void op_name(ComputeGraph& graph, const std::vector& args) { \ - const IntListPtr dims_list = graph.get_int_list(args[1]); \ - VK_CHECK_COND(dims_list->size() == 1); \ + const std::vector dims_list = \ + graph.extract_int_or_symint_list(args[1]); \ + VK_CHECK_COND(dims_list.size() == 1); \ + const int64_t dim_val = dims_list.at(0); \ + const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \ return add_reduce_node( \ - graph, args[0], dims_list->at(0), args[out_arg_idx], #op_name); \ + graph, args[0], dim_ref, args[out_arg_idx], #op_name); \ } DEFINE_REDUCE_FN(sum, 4) diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp index 31bab144d8a..fcc8fe4b265 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -8,6 +8,8 @@ #include +#include + #include namespace vkcompute { @@ -31,6 +33,22 @@ void resize_rotary_embedding_node( graph->virtual_resize(xk_out, xk_sizes); } +utils::uvec3 rotary_embedding_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef xq_out = args.at(0).refs.at(0); + + utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out); + global_wg_size[0] /= 2; + + return global_wg_size; +} + void add_rotary_embedding_node( ComputeGraph& graph, const ValueRef xq, @@ -57,17 +75,11 @@ void add_rotary_embedding_node( std::string kernel_name = "rotary_embedding"; add_dtype_suffix(kernel_name, graph.dtype_of(xq_out)); - utils::uvec3 global_wg_size = graph.logical_limits_of(xq_out); - global_wg_size[0] /= 2; - const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // Shader VK_KERNEL_FROM_STR(kernel_name), - // Workgroup sizes - global_wg_size, - local_wg_size, + rotary_embedding_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{{xq_out, xk_out}, vkapi::kWrite}, {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 5ef84347181..5ac8077d95f 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -22,6 +23,24 @@ namespace vkcompute { +utils::uvec3 kv_cache_update_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef cache = args.at(0).refs.at(0); + const ValueRef projected = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(cache)) { + return graph->create_global_wg_size(projected); + } else { + return graph->logical_limits_of(projected); + } +} + void add_kv_cache_update_node( ComputeGraph& graph, const ValueRef input_pos_symint, @@ -31,30 +50,24 @@ void add_kv_cache_update_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(projected)); add_dtype_suffix(kernel_name, graph.dtype_of(projected)); - utils::uvec3 global_size; vkapi::ParamsBindList param_ubos; if (graph.is_buffer_storage(cache)) { - global_size = graph.create_global_wg_size(projected); - param_ubos = { graph.numel_ubo(projected), graph.strides_ubo(cache), graph.get_or_create_int_param_buffer(input_pos_symint)}; } else { - global_size = graph.logical_limits_of(projected); - param_ubos = { graph.logical_limits_ubo(projected), graph.get_or_create_int_param_buffer(input_pos_symint)}; } - const utils::uvec3 local_size = graph.create_local_wg_size(global_size); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + kv_cache_update_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{cache, vkapi::kWrite}, {projected, vkapi::kRead}}, // Shader param buffers @@ -69,6 +82,27 @@ void add_kv_cache_update_node( nullptr)); } +utils::uvec3 attn_weight_scale_and_mask_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef attn_weight = args.at(0).refs.at(0); + + if (graph->is_buffer_storage(attn_weight)) { + return { + graph->size_at(-1, attn_weight), + graph->size_at(-2, attn_weight), + graph->size_at(-3, attn_weight), + }; + } else { + return graph->logical_limits_of(attn_weight); + } +} + void add_attn_weight_scale_and_mask_node( ComputeGraph& graph, const ValueRef input_pos_symint, @@ -81,37 +115,25 @@ void add_attn_weight_scale_and_mask_node( const int32_t head_dim_size = graph.size_at(-1, q_projected); const float scale_val = 1.0f / std::sqrt(static_cast(head_dim_size)); - utils::uvec3 global_size; - utils::uvec3 local_size; vkapi::ParamsBindList param_ubos; if (graph.is_buffer_storage(attn_weight)) { - global_size = { - graph.size_at(-1, attn_weight), - graph.size_at(-2, attn_weight), - graph.size_at(-3, attn_weight), - }; - param_ubos = { graph.sizes_ubo(attn_weight), graph.strides_ubo(attn_weight), graph.create_params_buffer(scale_val)}; } else { - global_size = graph.logical_limits_of(attn_weight); - param_ubos = { graph.logical_limits_ubo(attn_weight), graph.get_or_create_int_param_buffer(input_pos_symint), graph.create_params_buffer(scale_val)}; } - local_size = graph.create_local_wg_size(global_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + attn_weight_scale_and_mask_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{attn_weight, vkapi::kReadWrite}}, // Shader param buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp index 7469cbb0eb2..e37ef66434b 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -17,11 +18,55 @@ namespace vkcompute { using namespace utils; +utils::uvec3 pick_softmax_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + const int32_t reduce_dim_xyz = + graph->extract_scalar(resize_args.at(1)); + + utils::uvec3 global_size = graph->logical_limits_of(out); + global_size[reduce_dim_xyz] = 1; + return global_size; +} + +utils::uvec3 pick_softmax_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)global_workgroup_size; + (void)args; + + const int64_t group_dim_xyz = + graph->extract_scalar(resize_args.at(2)); + + const int32_t reduce_dim_xyz = + graph->extract_scalar(resize_args.at(1)); + + // These values are hardcoded in add_softmax_node + const uint32_t nworkers_per_group = 4; + const uint32_t ngroups = 4; + + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim_xyz] = nworkers_per_group; + local_wg_size[group_dim_xyz] = ngroups; + + return local_wg_size; +} + void resize_softmax_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; + const std::vector& resize_args) { + (void)resize_args; vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr in = graph->get_tensor(args[1].refs[0]); @@ -31,9 +76,9 @@ void resize_softmax_node( void add_softmax_node( ComputeGraph& graph, - ValueRef in, - ValueRef dim, - ValueRef out, + const ValueRef in, + const ValueRef dim_ref, + const ValueRef out, bool log_softmax) { VK_CHECK_COND( !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out), @@ -41,18 +86,18 @@ void add_softmax_node( const int64_t ndim = graph.dim_of(in); - int32_t reduce_dim = graph.extract_scalar(dim); - reduce_dim = normalize(reduce_dim, ndim); - reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + int32_t reduce_dim_nchw = graph.extract_scalar(dim_ref); + reduce_dim_nchw = normalize(reduce_dim_nchw, ndim); + const int32_t reduce_dim_xyz = nchw_dim_to_whcn_dim(reduce_dim_nchw, ndim); // Check that the concat dim is not the reduction dim, if the tensor has a // batch dim greater than 1. if (graph.dim_of(in) == 4 && graph.size_at(0, in) > 1) { VK_CHECK_COND( - graph.concat_dim_of(in) != reduce_dim, + graph.concat_dim_of(in) != reduce_dim_xyz, "Softmax shader currently does not support concat dim == reduce dim"); VK_CHECK_COND( - graph.concat_dim_of(out) != reduce_dim, + graph.concat_dim_of(out) != reduce_dim_xyz, "Softmax shader currently does not support concat dim == reduce dim"); } @@ -71,39 +116,36 @@ void add_softmax_node( const uint32_t ngroups = 4; VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads); - utils::uvec3 global_wg_size = graph.logical_limits_of(out); - global_wg_size[reduce_dim] = 1; - - utils::uvec3 local_wg_size{1, 1, 1}; - local_wg_size[reduce_dim] = nworkers_per_group; - const int other_dim_1 = (reduce_dim + 1) % 3; - const int other_dim_2 = (reduce_dim + 2) % 3; + // Determine the group dimension + const int other_dim_1 = (reduce_dim_xyz + 1) % 3; + const int other_dim_2 = (reduce_dim_xyz + 2) % 3; int32_t group_dim; + utils::uvec3 global_wg_size = graph.logical_limits_of(out); if (global_wg_size[other_dim_1] > global_wg_size[other_dim_2]) { - local_wg_size[other_dim_1] = ngroups; group_dim = other_dim_1; } else { - local_wg_size[other_dim_2] = ngroups; group_dim = other_dim_2; } - graph.execute_nodes().emplace_back(new DispatchNode( + const ValueRef reduce_dim_xyz_ref = + graph.get_or_add_value_for_int(reduce_dim_xyz); + const ValueRef group_dim_xyz_ref = graph.get_or_add_value_for_int(group_dim); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // shader_descriptor, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + pick_softmax_global_wg_size, + pick_softmax_local_wg_size, // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {in, vkapi::MemoryAccessType::READ}}, + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers {graph.logical_limits_ubo(out), graph.sizes_ubo(in)}, // Push Constants {}, // Specialization Constants - {graph.packed_dim_of(out), reduce_dim, group_dim}, + {graph.packed_dim_of(out), reduce_dim_xyz, group_dim}, // Resize Args - {}, + {dim_ref, reduce_dim_xyz_ref, group_dim_xyz_ref}, // Resizing Logic resize_softmax_node)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 8c060a9da4b..4c46596c206 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -10,7 +10,8 @@ #include -#include +#include +#include #include #include @@ -38,11 +39,11 @@ void add_staging_to_tensor_node( pcs = {graph.sizes_pc_of(out_tensor)}; } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - graph.create_global_wg_size(out_tensor), - graph.create_local_wg_size(out_tensor), + default_pick_global_wg_size, + default_pick_local_wg_size, // Input and Outputs {{out_tensor, vkapi::kWrite}, {in_staging, vkapi::kRead}}, // Parameter Buffers @@ -65,6 +66,44 @@ bool is_bitw8_shader(const vkapi::ShaderInfo& shader) { return shader_prefix_str == kBitw8PrefixStr; } +vkapi::ShaderInfo get_tensor_to_staging_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef in_tensor = args.at(1).refs.at(0); + return get_tensor_to_nchw_shader( + *graph->get_tensor(in_tensor), graph->int8_buffers_enabled()); +} + +utils::uvec3 tensor_to_staging_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef in_tensor = args.at(1).refs.at(0); + const ValueRef out_staging = args.at(0).refs.at(0); + + utils::uvec3 global_wg_size = graph->create_global_wg_size(in_tensor); + + // Normally, the image_to_nchw shader is structured so that each thread reads + // one texel from the input texture and writes each component of the texel + // into the corresponding location in the output buffer. However, this shader + // is structured slightly differently in that each thread writes out a + // complete 32 bit integer (containing 4 packed 8-bit integers) into the + // output buffer. Therefore, the global work group size for this shader will + // be the number of elements in the output buffer divided by 4, as opposed to + // the extents of the input texture. + if (is_bitw8_shader(shader)) { + const uint32_t buffer_len = utils::safe_downcast( + graph->get_staging(out_staging)->numel() / 4); + global_wg_size = {buffer_len, 1, 1}; + } + + return global_wg_size; +} + void add_tensor_to_staging_node( ComputeGraph& graph, const ValueRef in_tensor, @@ -74,8 +113,6 @@ void add_tensor_to_staging_node( vkapi::ShaderInfo shader = get_tensor_to_nchw_shader( *graph.get_tensor(in_tensor), graph.int8_buffers_enabled()); - utils::uvec3 global_wg_size = graph.create_global_wg_size(in_tensor); - vkapi::ParamsBindList ubos; if (graph.is_buffer_storage(in_tensor)) { ubos.append( @@ -86,25 +123,15 @@ void add_tensor_to_staging_node( ubos.append({graph.sizes_ubo(in_tensor)}); } - // Normally, the image_to_nchw shader is structured so that each thread reads - // one texel from the input texture and writes each component of the texel - // into the corresponding location in the output buffer. However, this shader - // is structured slightly differently in that each thread writes out a - // complete 32 bit integer (containing 4 packed 8-bit integers) into the - // output buffer. Therefore, the global work group size for this shader will - // be the number of elements in the output buffer divided by 4, as opposed to - // the extents of the input texture. if (is_bitw8_shader(shader)) { - uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4; - global_wg_size = {buffer_len, 1, 1}; ubos.append({graph.numel_ubo(in_tensor)}); } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - global_wg_size, - graph.create_local_wg_size(global_wg_size), + tensor_to_staging_global_wg_size, + default_pick_local_wg_size, // Input and Outputs {{out_staging, vkapi::kWrite}, {in_tensor, vkapi::kRead}}, // Parameter Buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index bffa8e2a181..518148f12eb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -51,14 +52,13 @@ void add_unary_op_node( ubos.append( {graph.create_params_buffer(min), graph.create_params_buffer(max)}); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {in, vkapi::MemoryAccessType::READ}}, + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers ubos, // Push Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 710ba0d576f..9dbe79faebb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -67,11 +68,11 @@ void add_view_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, *t_out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}}, diff --git a/backends/vulkan/runtime/utils/VecUtils.h b/backends/vulkan/runtime/utils/VecUtils.h index c084a563544..6d2e8c63bb9 100644 --- a/backends/vulkan/runtime/utils/VecUtils.h +++ b/backends/vulkan/runtime/utils/VecUtils.h @@ -260,12 +260,18 @@ struct vec final { } } - const Type& operator[](const uint32_t& i) const { + template < + typename IndexType, + typename = std::enable_if_t::value>> + const Type& operator[](const IndexType& i) const { VK_CHECK_COND(i >= 0 && i < N, "Index out of bounds!"); return data[i]; } - Type& operator[](const uint32_t& i) { + template < + typename IndexType, + typename = std::enable_if_t::value>> + Type& operator[](const IndexType& i) { VK_CHECK_COND(i >= 0 && i < N, "Index out of bounds!"); return data[i]; } diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp index dcd8c425d62..99ee1c0fa0b 100644 --- a/backends/vulkan/test/utils/test_utils.cpp +++ b/backends/vulkan/test/utils/test_utils.cpp @@ -547,8 +547,8 @@ vkcompute::ComputeGraph build_mm_graph( vkcompute::vkapi::ScalarType dtype, vkcompute::utils::StorageType in_out_stype, vkcompute::utils::GPUMemoryLayout memory_layout, - const bool prepack_mat2, - const float mat2_val) { + const std::vector& mat2_data, + const bool prepack_mat2) { using namespace vkcompute; GraphConfig config; ComputeGraph graph(config); @@ -569,10 +569,7 @@ vkcompute::ComputeGraph build_mm_graph( graph.add_input_tensor(mat1_size, dtype, in_out_stype, memory_layout); IOValueRef mat2{}; - CREATE_RAND_WEIGHT_TENSOR(mat2_w, mat2_size, dtype); - if (mat2_val != 0.0f) { - std::fill(data_mat2_w.begin(), data_mat2_w.end(), mat2_val); - } + ValueRef mat2_w = graph.add_tensorref(mat2_size, dtype, mat2_data.data()); if (prepack_mat2) { mat2.value = mat2_w; diff --git a/backends/vulkan/test/utils/test_utils.h b/backends/vulkan/test/utils/test_utils.h index 71d6d0bc0de..0f0d2647792 100644 --- a/backends/vulkan/test/utils/test_utils.h +++ b/backends/vulkan/test/utils/test_utils.h @@ -265,8 +265,8 @@ vkcompute::ComputeGraph build_mm_graph( vkcompute::vkapi::ScalarType dtype, vkcompute::utils::StorageType in_out_stype, vkcompute::utils::GPUMemoryLayout memory_layout, - const bool prepack_mat2 = false, - const float mat2_val = 0.0f); + const std::vector& mat2_data, + const bool prepack_mat2 = false); // // Debugging Utilities diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index 85811aaaf11..9887f3c7ffb 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -2751,15 +2751,19 @@ void test_mm( utils::StorageType storage_type, utils::GPUMemoryLayout memory_layout, bool prepack = true) { + std::vector mat2_size = {B, K, N}; + + std::vector mat2_data(utils::multiply_integers(mat2_size)); + std::fill(mat2_data.begin(), mat2_data.end(), 2.0f); ComputeGraph graph = build_mm_graph( - B, M, K, N, dtype, storage_type, memory_layout, prepack, 2.0f); + B, M, K, N, dtype, storage_type, memory_layout, mat2_data, prepack); graph.prepare(); graph.encode_prepack(); graph.prepack(); - graph.encode_execute(); for (int i = 1; i < 4; i++) { + graph.encode_execute(); if (prepack) { float val_mat1 = i; float val_out = K * (val_mat1 * 2.0f); @@ -2828,8 +2832,12 @@ void test_mm_with_resize_reencode( utils::GPUMemoryLayout memory_layout) { ASSERT_TRUE(M > 1); + std::vector mat2_size = {B, K, N}; + std::vector mat2_data(utils::multiply_integers(mat2_size)); + std::fill(mat2_data.begin(), mat2_data.end(), 2.0f); + ComputeGraph graph = build_mm_graph( - B, M, K, N, dtype, storage_type, memory_layout, false, 2.0f); + B, M, K, N, dtype, storage_type, memory_layout, mat2_data, false); graph.prepare(); graph.encode_prepack(); @@ -2851,8 +2859,6 @@ void test_mm_with_resize_reencode( graph.resize_input(1, new_mat2_size); graph.propagate_resize(); - graph.encode_execute(); - for (int i = 1; i < 4; i++) { float val_mat1 = i; float val_mat2 = i + 1;