From 891ebebb8944f75375206386f1b69e2c43f62d9b Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Tue, 3 Jun 2025 09:27:57 -0700
Subject: [PATCH] [ET-VK] Migrate ops to use `DynamicDispatchNode`

## Changes

* Migrate operators that are used in the llama model to use `DynamicDispatchNode` instead of `DispatchNode`

## Motivation

`DynamicDispatchNode` is a subclass of `DispatchNode` that allows the compute
shader, as well as the global and local work group sizes, to be selected
dynamically whenever the command buffer is encoded. This is critical for
ensuring optimal performance when input shapes are dynamic, since it allows
each operator to select the best compute shader for the current input
conditions and to shrink the global work group size so that only the minimum
number of work groups is launched.

Without this change, performance of llama 3.2 1B with dynamic shapes enabled
is terrible (< 1 tok/s), because global work group sizes are determined from
maximum tensor sizes, which are in turn derived from the maximum sequence
length. In practice, the sequence length dimension of the tensors (even during
the prefill phase) will not approach that maximum, so a large number of
inactive threads are launched during each compute shader dispatch.

Differential Revision: [D75878398](https://our.internmc.facebook.com/intern/diff/D75878398/)

[ghstack-poisoned]
---
 .../vulkan/runtime/graph/ComputeGraph.cpp     |   9 +
 backends/vulkan/runtime/graph/ComputeGraph.h  |   7 +
 .../runtime/graph/ops/impl/BinaryOp.cpp       |  17 +-
 .../vulkan/runtime/graph/ops/impl/Clone.cpp   |  25 ++-
 .../vulkan/runtime/graph/ops/impl/MatMul.cpp  | 161 +++++++++++-------
 .../graph/ops/impl/QuantizedLinearQGANW.cpp   | 107 ++++++++----
 .../vulkan/runtime/graph/ops/impl/Reduce.cpp  |  95 +++++++----
 .../graph/ops/impl/RotaryEmbedding.cpp        |  29 +++-
 .../vulkan/runtime/graph/ops/impl/SDPA.cpp    |  68 +++++---
 .../vulkan/runtime/graph/ops/impl/Softmax.cpp |  89 +++++++---
 .../vulkan/runtime/graph/ops/impl/Staging.cpp |  62 ++++---
 .../vulkan/runtime/graph/ops/impl/UnaryOp.cpp |  10 +-
 .../vulkan/runtime/graph/ops/impl/View.cpp    |   7 +-
 13 files changed, 456 insertions(+), 230 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 1222a9fc641..68935e63123 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -449,6 +449,15 @@ ValueRef ComputeGraph::add_symint(const int32_t val) {
   return idx;
 }
 
+ValueRef ComputeGraph::get_or_add_value_for_int(const int64_t val) {
+  for (int i = 0; i < values_.size(); ++i) {
+    if (values_.at(i).isInt() && values_.at(i).toInt() == val) {
+      return i;
+    }
+  }
+  return add_scalar(val);
+}
+
 ValueRef ComputeGraph::set_input_tensor(
     const ValueRef idx,
     const bool use_staging) {
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index fe546f26477..0e8a5eba51f 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -604,6 +604,13 @@ class ComputeGraph final {
 
   ValueRef add_symint(const int32_t val);
 
+  /*
+   * Searches the graph's value list for an Int value equal to the specified
+   * value. If one is found, returns the index of that value. Otherwise, adds
+   * a new Int value and returns its index.
+ */ + ValueRef get_or_add_value_for_int(const int64_t val); + ValueRef set_input_tensor(const ValueRef idx, const bool use_staging = true); ValueRef set_output_tensor(const ValueRef idx, const bool use_staging = true); diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index ff6b54c5289..d260ed767d0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -30,8 +31,8 @@ void check_binary_op_args( void resize_binary_op_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; + const std::vector& resize_args) { + (void)resize_args; vTensorPtr out = graph->get_tensor(args[0].refs[0]); // TODO(T183442143): Verify tensors are broadcastable. @@ -78,11 +79,11 @@ void add_binary_op_texture_node( add_storage_type_suffix(kernel_name, *t_out); add_dtype_suffix(kernel_name, *t_out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{arg1, arg2}, vkapi::kRead}}, // Shader params buffers @@ -122,11 +123,11 @@ void add_binary_op_buffer_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{in1, in2}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp index d0276b1783b..2b9bd08eb12 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Clone.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Clone.cpp @@ -10,6 +10,7 @@ #include +#include #include #include @@ -21,8 +22,8 @@ namespace vkcompute { void resize_clone_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; + const std::vector& resize_args) { + (void)resize_args; vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr in = graph->get_tensor(args[1].refs[0]); // TODO: support for when dimensionality doesn't match, i.e. 
clone is used to @@ -41,11 +42,11 @@ void add_clone_node( std::string kernel_name = "clone"; add_dtype_suffix(kernel_name, *t_out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Parameter Buffers @@ -68,12 +69,11 @@ void add_image_to_buffer_node( add_dtype_suffix(kernel_name, graph.dtype_of(image)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); - utils::uvec3 global_wg_size = graph.create_global_wg_size(image); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - global_wg_size, - graph.create_local_wg_size(global_wg_size), + default_pick_global_wg_size, + default_pick_local_wg_size, // Input and Outputs {{buffer, vkapi::kWrite}, {image, vkapi::kRead}}, // Parameter Buffers @@ -96,12 +96,11 @@ void add_buffer_to_image_node( add_dtype_suffix(kernel_name, graph.dtype_of(image)); vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name); - utils::uvec3 global_wg_size = graph.create_global_wg_size(image); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - global_wg_size, - graph.create_local_wg_size(global_wg_size), + default_pick_global_wg_size, + default_pick_local_wg_size, // Input and Outputs {{image, vkapi::kWrite}, {buffer, vkapi::kRead}}, // Parameter Buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 724f4630264..7abba06d3b8 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -37,12 +38,12 @@ void check_matmul_args( void resize_matmul_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { + const std::vector& resize_args) { vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr mat1 = graph->get_tensor(args[1].refs[0]); vTensorPtr mat2 = graph->get_tensor(args[1].refs[1]); - bool mat2_is_transposed = graph->get_bool(extra_args[0]); + bool mat2_is_transposed = graph->get_bool(resize_args[0]); const int out_cols = utils::val_at(-2, mat1->sizes()); const int out_rows = mat2_is_transposed ? utils::val_at(-2, mat2->sizes()) @@ -56,6 +57,22 @@ void resize_matmul_node( out->virtual_resize(new_out_sizes); } +/** + * Custom global workgroup size function for naive buffer matmul operations. 
+ */ +utils::uvec3 matmul_naive_buffer_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + const ValueRef out = args.at(0).refs.at(0); + return { + graph->size_at(-1, out), + graph->size_at(-2, out), + graph->size_at(-3, out) * graph->size_at(-4, out)}; +} + void add_matmul_naive_buffer_node( ComputeGraph& graph, const ValueRef mat1, @@ -72,21 +89,16 @@ void add_matmul_naive_buffer_node( std::string kernel_name = "matmul_naive_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); - utils::uvec3 global_size = { - graph.size_at(-1, out), - graph.size_at(-2, out), - graph.size_at(-3, out) * graph.size_at(-4, out)}; - int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef && graph.get_bool(mat2_is_transposed)) ? 1 : 0; - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - graph.create_local_wg_size(global_size), + matmul_naive_buffer_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, // Shader params buffers @@ -109,6 +121,22 @@ void add_matmul_naive_buffer_node( resize_matmul_node)); } +vkapi::ShaderInfo pick_matmul_naive_texture3d_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const bool is_transposed = graph->get_bool(resize_args.at(0)); + + std::string kernel_name = + is_transposed ? "matmul_transposed_naive" : "matmul_naive"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(kernel_name); +} + void add_matmul_naive_texture3d_node( ComputeGraph& graph, const ValueRef mat1, @@ -122,19 +150,11 @@ void add_matmul_naive_texture3d_node( utils::kHeightPacked, /*passthrough = */ true); - std::string kernel_name = graph.get_bool(mat2_is_transposed) - ? "matmul_transposed_naive" - : "matmul_naive"; - kernel_name.reserve(kShaderNameReserve); - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - utils::uvec3 global_wg_size = graph.logical_limits_of(out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - graph.create_local_wg_size(global_wg_size), + pick_matmul_naive_texture3d_shader, + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, // Shader params buffers @@ -156,6 +176,59 @@ void add_matmul_naive_texture3d_node( resize_matmul_node)); } +vkapi::ShaderInfo pick_matmul_optimized_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1_W_packed = resize_args.at(1); + const bool mat2_is_transposed_val = graph->get_bool(resize_args.at(0)); + + std::string kernel_name = mat2_is_transposed_val + ? 
"matmul_transposed_optimized" + : "matmul_optimized"; + + std::vector mat1_sizes = graph->sizes_of(mat1_W_packed); + int mat1_dims = mat1_sizes.size(); + if (mat1_dims == 3) { + kernel_name = "batch_" + kernel_name; + } + if (mat1_sizes.at(mat1_dims - 2) < 8) { + kernel_name += "_tile_row_2"; + } else { + kernel_name += "_tile_row_4"; + } + + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 matmul_optimized_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + + const ValueRef out = args.at(0).refs.at(0); + const ValueRef mat1_W_packed = resize_args.at(1); + + const std::vector mat1_sizes = graph->sizes_of(mat1_W_packed); + const int mat1_dims = mat1_sizes.size(); + + utils::uvec3 global_size = graph->logical_limits_of(out); + if (mat1_sizes.at(mat1_dims - 2) < 8) { + // Use `logical_extents` instead of `image_extents` because the workgroup + // axes need to correspond to tensor dimensions. + global_size = utils::divup_vec(global_size, {4, 2, 1}); + } else { + global_size = utils::divup_vec(global_size, {4, 4, 1}); + } + + return global_size; +} + void add_matmul_optimized_node( ComputeGraph& graph, const ValueRef mat1, @@ -192,45 +265,11 @@ void add_matmul_optimized_node( viewFn(graph, {mat2, graph.add_none(), mat2_packed}); } - std::string kernel_name = mat2_is_transposed_val - ? "matmul_transposed_optimized" - : "matmul_optimized"; - - std::vector mat1_sizes = graph.sizes_of(mat1_W_packed); - int mat1_dims = mat1_sizes.size(); - if (mat1_dims == 3) { - kernel_name = "batch_" + kernel_name; - } - if (mat1_sizes.at(mat1_dims - 2) < 8) { - kernel_name += "_tile_row_2"; - } else { - kernel_name += "_tile_row_4"; - } - - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the - // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is - // channels packed, C does not need to be divided by 4. The "identity" of each - // thread is the (x, y, z) coordinate of the output tile it is computing, and - // this identity can be used to compute the tensor index of the top left - // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0] - utils::uvec3 global_size = graph.logical_limits_of(out); - if (mat1_sizes.at(mat1_dims - 2) < 8) { - // Use `logical_extents` instead of `image_extents` because the workgroup - // axes need to correspond to tensor dimensions. 
-    global_size = utils::divup_vec(global_size, {4, 2, 1});
-  } else {
-    global_size = utils::divup_vec(global_size, {4, 4, 1});
-  }
-
-  utils::uvec3 local_size = adaptive_work_group_size(global_size);
-
-  graph.execute_nodes().emplace_back(new DispatchNode(
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
-      VK_KERNEL_FROM_STR(kernel_name),
-      global_size,
-      local_size,
+      pick_matmul_optimized_shader,
+      matmul_optimized_global_wg_size,
+      default_pick_local_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}},
       // Shader params buffers
@@ -246,7 +285,7 @@ void add_matmul_optimized_node(
           graph.hashed_layout_of(mat1_W_packed),
           graph.hashed_layout_of(mat2_packed)},
       // Resize Args
-      {mat2_is_transposed},
+      {mat2_is_transposed, mat1_W_packed},
       // Resizing Logic
       resize_matmul_node));
 }
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp
index 8c5cb0093d9..abb7a882663 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp
@@ -8,6 +8,7 @@
 
 #include
 
+#include
 #include
 #include
 
@@ -70,6 +71,75 @@ void resize_linear_qga4w_node(
   out->virtual_resize(new_out_sizes);
 }
 
+/**
+ * Determines if the cooperative algorithm should be used based on input tensor
+ * dimensions. Apply the coop algorithm for gemv cases, i.e. when mat1 is a
+ * vector as opposed to a matrix.
+ */
+bool should_use_coop_algorithm(ComputeGraph* graph, const ValueRef& mat1) {
+  return graph->size_at(-2, mat1) == 1;
+}
+
+vkapi::ShaderInfo pick_linear_qga4w_shader(
+    ComputeGraph* graph,
+    const std::vector& args,
+    const std::vector& resize_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef mat1 = args.at(1).refs.at(0);
+  const ValueRef mat2 = args.at(1).refs.at(1);
+
+  const bool use_coop_algorithm = should_use_coop_algorithm(graph, mat1);
+
+  std::string kernel_name = "linear_qga4w";
+  if (use_coop_algorithm) {
+    kernel_name += "_coop";
+  } else {
+    kernel_name += "_tiled";
+  }
+  add_storage_type_suffix(kernel_name, graph->storage_type_of(out));
+  add_storage_type_suffix(kernel_name, graph->storage_type_of(mat1));
+  add_storage_type_suffix(kernel_name, graph->storage_type_of(mat2));
+  add_dtype_suffix(kernel_name, graph->dtype_of(out));
+
+  return VK_KERNEL_FROM_STR(kernel_name);
+}
+
+utils::uvec3 linear_qga4w_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector& args,
+    const std::vector& resize_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+
+  const bool use_coop_algorithm =
+      shader.kernel_name.find("_coop") != std::string::npos;
+
+  utils::uvec3 global_wg_size = graph->logical_limits_of(out);
+  global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2));
+
+  if (!use_coop_algorithm) {
+    global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3));
+  }
+
+  return global_wg_size;
+}
+
+utils::uvec3 linear_qga4w_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector& args,
+    const std::vector& resize_args) {
+  const bool use_coop_algorithm =
+      shader.kernel_name.find("_coop") != std::string::npos;
+
+  if (use_coop_algorithm) {
+    return {8, 1, 8};
+  } else {
+    return graph->create_local_wg_size(global_workgroup_size);
+  }
+}
+
 void add_linear_qga4w_node(
     ComputeGraph& graph,
     const ValueRef mat1,
     const ValueRef mat2_data,
     const ValueRef group_size,
     const ValueRef scales_and_zeros_data,
     const ValueRef out) {
@@ -82,45 +152,16 @@ void
add_linear_qga4w_node( const uint32_t group_size_val = graph.extract_scalar(group_size); - bool use_coop_algorithm = false; - // Apply the coop algorithm for gemv cases, i.e. mat1 is a vector as opposed - // to a matrix. - if (graph.size_at(-2, mat1) == 1) { - use_coop_algorithm = true; - } - ValueRef mat2 = prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); ValueRef scales_and_zeros = prepack_standard_hw_transposed( graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); - - std::string kernel_name = "linear_qga4w"; - if (use_coop_algorithm) { - kernel_name += "_coop"; - } else { - kernel_name += "_tiled"; - } - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(mat1)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(mat2)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - - utils::uvec3 global_wg_size = graph.logical_limits_of(out); - global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); - utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - if (use_coop_algorithm) { - local_wg_size = {8, 1, 8}; - } else { - global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); - } - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + pick_linear_qga4w_shader, + linear_qga4w_global_wg_size, + linear_qga4w_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {{mat1, mat2, scales_and_zeros}, vkapi::kRead}}, // Shader params buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp index 8fcd4a0609c..c93c0954bc2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Reduce.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -20,23 +21,65 @@ using namespace utils; void resize_reduce_node( ComputeGraph* graph, const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; + const std::vector& resize_args) { vTensorPtr out = graph->get_tensor(args[0].refs[0]); vTensorPtr in = graph->get_tensor(args[1].refs[0]); - int dim = extra_args[0]; + int32_t reduce_dim_nchw = graph->extract_scalar(resize_args.at(0)); std::vector new_sizes = in->sizes(); - new_sizes[normalize(dim, new_sizes.size())] = 1; + new_sizes.at(normalize(reduce_dim_nchw, new_sizes.size())) = 1; out->virtual_resize(new_sizes); } +utils::uvec3 reduce_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + const ValueRef out = args.at(0).refs.at(0); + const int32_t reduce_dim_whcn = + graph->extract_scalar(resize_args.at(1)); + + utils::uvec3 global_wg_size = graph->logical_limits_of(out); + global_wg_size[reduce_dim_whcn] = 1; + return global_wg_size; +} + +utils::uvec3 reduce_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)global_workgroup_size; + + const int32_t reduce_dim_whcn = + graph->extract_scalar(resize_args.at(1)); + const int64_t group_dim_whcn = + graph->extract_scalar(resize_args.at(2)); + + // This should match the value of MAX_NTHREADS in the reduce shader. 
+ constexpr uint32_t max_nthreads = 16; + + const uint32_t nworkers_per_group = 4; + const uint32_t ngroups = 4; + VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads); + + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim_whcn] = nworkers_per_group; + local_wg_size[group_dim_whcn] = ngroups; + + return local_wg_size; +} + void add_reduce_node( ComputeGraph& graph, - ValueRef in, - const int dim, - ValueRef out, + const ValueRef in, + const ValueRef dim_ref, + const ValueRef out, const std::string& op_name) { VK_CHECK_COND( !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out), @@ -44,7 +87,7 @@ void add_reduce_node( const int64_t ndim = graph.dim_of(in); - int32_t reduce_dim = dim; + int32_t reduce_dim = graph.extract_scalar(dim_ref); reduce_dim = normalize(reduce_dim, ndim); reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); @@ -55,40 +98,30 @@ void add_reduce_node( VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim); } - vkapi::ShaderInfo shader_descriptor; std::string kernel_name = op_name; kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, graph.dtype_of(out)); - // This should match the value of MAX_NTHREADS in the softmax shader. - constexpr uint32_t max_nthreads = 16; - - const uint32_t nworkers_per_group = 4; - const uint32_t ngroups = 4; - VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads); - - utils::uvec3 global_wg_size = graph.logical_limits_of(out); - global_wg_size[reduce_dim] = 1; - - utils::uvec3 local_wg_size{1, 1, 1}; - local_wg_size[reduce_dim] = nworkers_per_group; + // Calculate group_dim for specialization constants const int other_dim_1 = (reduce_dim + 1) % 3; const int other_dim_2 = (reduce_dim + 2) % 3; int32_t group_dim; - if (global_wg_size[other_dim_1] > global_wg_size[other_dim_2]) { - local_wg_size[other_dim_1] = ngroups; + utils::uvec3 limits = graph.logical_limits_of(out); + if (limits[other_dim_1] > limits[other_dim_2]) { group_dim = other_dim_1; } else { - local_wg_size[other_dim_2] = ngroups; group_dim = other_dim_2; } - graph.execute_nodes().emplace_back(new DispatchNode( + const ValueRef reduce_dim_whcn_ref = + graph.get_or_add_value_for_int(reduce_dim); + const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int(group_dim); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // shader_descriptor, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + reduce_global_wg_size, + reduce_local_wg_size, // Inputs and Outputs {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers @@ -98,7 +131,7 @@ void add_reduce_node( // Specialization Constants {graph.packed_dim_of(out), reduce_dim, group_dim}, // Resize Args - {dim}, + {dim_ref, reduce_dim_whcn_ref, group_dim_whcn_ref}, // Resizing Logic resize_reduce_node)); } @@ -107,8 +140,10 @@ void add_reduce_node( void op_name(ComputeGraph& graph, const std::vector& args) { \ const IntListPtr dims_list = graph.get_int_list(args[1]); \ VK_CHECK_COND(dims_list->size() == 1); \ + const int64_t dim_val = dims_list->at(0); \ + const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \ return add_reduce_node( \ - graph, args[0], dims_list->at(0), args[out_arg_idx], #op_name); \ + graph, args[0], dim_ref, args[out_arg_idx], #op_name); \ } DEFINE_REDUCE_FN(sum, 4) diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp index ee40a043ee5..044afdb8a60 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp 
+++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -8,6 +8,8 @@ #include +#include + #include namespace vkcompute { @@ -25,6 +27,21 @@ void resize_rotary_embedding_node( // out->virtual_resize(in_sizes); } +utils::uvec3 rotary_embedding_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + + const ValueRef xq_out = args.at(0).refs.at(0); + + utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out); + global_wg_size[0] /= 2; + + return global_wg_size; +} + void add_rotary_embedding_node( ComputeGraph& graph, const ValueRef xq, @@ -51,17 +68,11 @@ void add_rotary_embedding_node( std::string kernel_name = "rotary_embedding"; add_dtype_suffix(kernel_name, graph.dtype_of(xq_out)); - utils::uvec3 global_wg_size = graph.logical_limits_of(xq_out); - global_wg_size[0] /= 2; - const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // Shader VK_KERNEL_FROM_STR(kernel_name), - // Workgroup sizes - global_wg_size, - local_wg_size, + rotary_embedding_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{{xq_out, xk_out}, vkapi::kWrite}, {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}}, diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 5ef84347181..9ea76b08553 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -22,6 +23,23 @@ namespace vkcompute { +utils::uvec3 kv_cache_update_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + + const ValueRef cache = args.at(0).refs.at(0); + const ValueRef projected = args.at(1).refs.at(0); + + if (graph->is_buffer_storage(cache)) { + return graph->create_global_wg_size(projected); + } else { + return graph->logical_limits_of(projected); + } +} + void add_kv_cache_update_node( ComputeGraph& graph, const ValueRef input_pos_symint, @@ -31,30 +49,24 @@ void add_kv_cache_update_node( add_storage_type_suffix(kernel_name, graph.storage_type_of(projected)); add_dtype_suffix(kernel_name, graph.dtype_of(projected)); - utils::uvec3 global_size; vkapi::ParamsBindList param_ubos; if (graph.is_buffer_storage(cache)) { - global_size = graph.create_global_wg_size(projected); - param_ubos = { graph.numel_ubo(projected), graph.strides_ubo(cache), graph.get_or_create_int_param_buffer(input_pos_symint)}; } else { - global_size = graph.logical_limits_of(projected); - param_ubos = { graph.logical_limits_ubo(projected), graph.get_or_create_int_param_buffer(input_pos_symint)}; } - const utils::uvec3 local_size = graph.create_local_wg_size(global_size); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + kv_cache_update_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{cache, vkapi::kWrite}, {projected, vkapi::kRead}}, // Shader param buffers @@ -69,6 +81,26 @@ void add_kv_cache_update_node( nullptr)); } +utils::uvec3 attn_weight_scale_and_mask_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& 
resize_args) { + (void)shader; + + const ValueRef attn_weight = args.at(0).refs.at(0); + + if (graph->is_buffer_storage(attn_weight)) { + return { + graph->size_at(-1, attn_weight), + graph->size_at(-2, attn_weight), + graph->size_at(-3, attn_weight), + }; + } else { + return graph->logical_limits_of(attn_weight); + } +} + void add_attn_weight_scale_and_mask_node( ComputeGraph& graph, const ValueRef input_pos_symint, @@ -81,37 +113,25 @@ void add_attn_weight_scale_and_mask_node( const int32_t head_dim_size = graph.size_at(-1, q_projected); const float scale_val = 1.0f / std::sqrt(static_cast(head_dim_size)); - utils::uvec3 global_size; - utils::uvec3 local_size; vkapi::ParamsBindList param_ubos; if (graph.is_buffer_storage(attn_weight)) { - global_size = { - graph.size_at(-1, attn_weight), - graph.size_at(-2, attn_weight), - graph.size_at(-3, attn_weight), - }; - param_ubos = { graph.sizes_ubo(attn_weight), graph.strides_ubo(attn_weight), graph.create_params_buffer(scale_val)}; } else { - global_size = graph.logical_limits_of(attn_weight); - param_ubos = { graph.logical_limits_ubo(attn_weight), graph.get_or_create_int_param_buffer(input_pos_symint), graph.create_params_buffer(scale_val)}; } - local_size = graph.create_local_wg_size(global_size); - - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - global_size, - local_size, + attn_weight_scale_and_mask_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{attn_weight, vkapi::kReadWrite}}, // Shader param buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp index 7469cbb0eb2..4aa1e9c1012 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -17,6 +18,47 @@ namespace vkcompute { using namespace utils; +utils::uvec3 pick_softmax_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& extra_args) { + (void)shader; + + const ValueRef out = args.at(0).refs.at(0); + const int32_t reduce_dim_xyz = + graph->extract_scalar(extra_args.at(1)); + + utils::uvec3 global_size = graph->logical_limits_of(out); + global_size[reduce_dim_xyz] = 1; + return global_size; +} + +utils::uvec3 pick_softmax_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& extra_args) { + (void)shader; + + const int64_t group_dim_xyz = + graph->extract_scalar(extra_args.at(2)); + + const int32_t reduce_dim_xyz = + graph->extract_scalar(extra_args.at(1)); + + // These values are hardcoded in add_softmax_node + const uint32_t nworkers_per_group = 4; + const uint32_t ngroups = 4; + + utils::uvec3 local_wg_size{1, 1, 1}; + local_wg_size[reduce_dim_xyz] = nworkers_per_group; + local_wg_size[group_dim_xyz] = ngroups; + + return local_wg_size; +} + void resize_softmax_node( ComputeGraph* graph, const std::vector& args, @@ -31,9 +73,9 @@ void resize_softmax_node( void add_softmax_node( ComputeGraph& graph, - ValueRef in, - ValueRef dim, - ValueRef out, + const ValueRef in, + const ValueRef dim_ref, + const ValueRef out, bool log_softmax) { VK_CHECK_COND( !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out), @@ -41,18 +83,18 @@ void add_softmax_node( const int64_t ndim = graph.dim_of(in); - 
int32_t reduce_dim = graph.extract_scalar(dim); - reduce_dim = normalize(reduce_dim, ndim); - reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim); + int32_t reduce_dim_nchw = graph.extract_scalar(dim_ref); + reduce_dim_nchw = normalize(reduce_dim_nchw, ndim); + const int32_t reduce_dim_xyz = nchw_dim_to_whcn_dim(reduce_dim_nchw, ndim); // Check that the concat dim is not the reduction dim, if the tensor has a // batch dim greater than 1. if (graph.dim_of(in) == 4 && graph.size_at(0, in) > 1) { VK_CHECK_COND( - graph.concat_dim_of(in) != reduce_dim, + graph.concat_dim_of(in) != reduce_dim_xyz, "Softmax shader currently does not support concat dim == reduce dim"); VK_CHECK_COND( - graph.concat_dim_of(out) != reduce_dim, + graph.concat_dim_of(out) != reduce_dim_xyz, "Softmax shader currently does not support concat dim == reduce dim"); } @@ -71,39 +113,36 @@ void add_softmax_node( const uint32_t ngroups = 4; VK_CHECK_COND(nworkers_per_group * ngroups <= max_nthreads); - utils::uvec3 global_wg_size = graph.logical_limits_of(out); - global_wg_size[reduce_dim] = 1; - - utils::uvec3 local_wg_size{1, 1, 1}; - local_wg_size[reduce_dim] = nworkers_per_group; - const int other_dim_1 = (reduce_dim + 1) % 3; - const int other_dim_2 = (reduce_dim + 2) % 3; + // Determine the group dimension + const int other_dim_1 = (reduce_dim_xyz + 1) % 3; + const int other_dim_2 = (reduce_dim_xyz + 2) % 3; int32_t group_dim; + utils::uvec3 global_wg_size = graph.logical_limits_of(out); if (global_wg_size[other_dim_1] > global_wg_size[other_dim_2]) { - local_wg_size[other_dim_1] = ngroups; group_dim = other_dim_1; } else { - local_wg_size[other_dim_2] = ngroups; group_dim = other_dim_2; } - graph.execute_nodes().emplace_back(new DispatchNode( + const ValueRef reduce_dim_xyz_ref = + graph.get_or_add_value_for_int(reduce_dim_xyz); + const ValueRef group_dim_xyz_ref = graph.get_or_add_value_for_int(group_dim); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, - // shader_descriptor, VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - local_wg_size, + pick_softmax_global_wg_size, + pick_softmax_local_wg_size, // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {in, vkapi::MemoryAccessType::READ}}, + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers {graph.logical_limits_ubo(out), graph.sizes_ubo(in)}, // Push Constants {}, // Specialization Constants - {graph.packed_dim_of(out), reduce_dim, group_dim}, + {graph.packed_dim_of(out), reduce_dim_xyz, group_dim}, // Resize Args - {}, + {dim_ref, reduce_dim_xyz_ref, group_dim_xyz_ref}, // Resizing Logic resize_softmax_node)); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 8c060a9da4b..e8d5420af92 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -10,7 +10,8 @@ #include -#include +#include +#include #include #include @@ -38,11 +39,11 @@ void add_staging_to_tensor_node( pcs = {graph.sizes_pc_of(out_tensor)}; } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - graph.create_global_wg_size(out_tensor), - graph.create_local_wg_size(out_tensor), + default_pick_global_wg_size, + default_pick_local_wg_size, // Input and Outputs {{out_tensor, vkapi::kWrite}, {in_staging, vkapi::kRead}}, // Parameter Buffers @@ -65,6 +66,41 @@ bool is_bitw8_shader(const vkapi::ShaderInfo& shader) { return 
shader_prefix_str == kBitw8PrefixStr; } +vkapi::ShaderInfo get_tensor_to_staging_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef in_tensor = args.at(1).refs.at(0); + return get_tensor_to_nchw_shader( + *graph->get_tensor(in_tensor), graph->int8_buffers_enabled()); +} + +utils::uvec3 tensor_to_staging_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef in_tensor = args.at(1).refs.at(0); + const ValueRef out_staging = args.at(0).refs.at(0); + + utils::uvec3 global_wg_size = graph->create_global_wg_size(in_tensor); + + // Normally, the image_to_nchw shader is structured so that each thread reads + // one texel from the input texture and writes each component of the texel + // into the corresponding location in the output buffer. However, this shader + // is structured slightly differently in that each thread writes out a + // complete 32 bit integer (containing 4 packed 8-bit integers) into the + // output buffer. Therefore, the global work group size for this shader will + // be the number of elements in the output buffer divided by 4, as opposed to + // the extents of the input texture. + if (is_bitw8_shader(shader)) { + const uint32_t buffer_len = graph->get_staging(out_staging)->numel() / 4; + global_wg_size = {buffer_len, 1, 1}; + } + + return global_wg_size; +} + void add_tensor_to_staging_node( ComputeGraph& graph, const ValueRef in_tensor, @@ -74,8 +110,6 @@ void add_tensor_to_staging_node( vkapi::ShaderInfo shader = get_tensor_to_nchw_shader( *graph.get_tensor(in_tensor), graph.int8_buffers_enabled()); - utils::uvec3 global_wg_size = graph.create_global_wg_size(in_tensor); - vkapi::ParamsBindList ubos; if (graph.is_buffer_storage(in_tensor)) { ubos.append( @@ -86,25 +120,15 @@ void add_tensor_to_staging_node( ubos.append({graph.sizes_ubo(in_tensor)}); } - // Normally, the image_to_nchw shader is structured so that each thread reads - // one texel from the input texture and writes each component of the texel - // into the corresponding location in the output buffer. However, this shader - // is structured slightly differently in that each thread writes out a - // complete 32 bit integer (containing 4 packed 8-bit integers) into the - // output buffer. Therefore, the global work group size for this shader will - // be the number of elements in the output buffer divided by 4, as opposed to - // the extents of the input texture. 
if (is_bitw8_shader(shader)) { - uint32_t buffer_len = graph.get_staging(out_staging)->numel() / 4; - global_wg_size = {buffer_len, 1, 1}; ubos.append({graph.numel_ubo(in_tensor)}); } - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, - global_wg_size, - graph.create_local_wg_size(global_wg_size), + tensor_to_staging_global_wg_size, + default_pick_local_wg_size, // Input and Outputs {{out_staging, vkapi::kWrite}, {in_tensor, vkapi::kRead}}, // Parameter Buffers diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index bffa8e2a181..518148f12eb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -51,14 +52,13 @@ void add_unary_op_node( ubos.append( {graph.create_params_buffer(min), graph.create_params_buffer(max)}); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {in, vkapi::MemoryAccessType::READ}}, + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, // Shader params buffers ubos, // Push Constants diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 710ba0d576f..9dbe79faebb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -8,6 +8,7 @@ #include +#include #include #include @@ -67,11 +68,11 @@ void add_view_node( kernel_name.reserve(kShaderNameReserve); add_dtype_suffix(kernel_name, *t_out); - graph.execute_nodes().emplace_back(new DispatchNode( + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + default_pick_global_wg_size, + default_pick_local_wg_size, // Inputs and Outputs {{out, vkapi::MemoryAccessType::WRITE}, {in, vkapi::MemoryAccessType::READ}},
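For reviewers unfamiliar with the new node type: each migrated op now registers resolver callbacks instead of baked-in values, and the callbacks are re-evaluated every time the command buffer is encoded. Below is a minimal sketch of what a migration looks like for a hypothetical op; the op name `my_op`, the element types of the callback vectors (`ArgGroup`, `ValueRef`), and the omitted include path are assumptions for illustration and are not part of this patch.

```cpp
// Illustrative sketch only -- "my_op" is a placeholder, and the vector element
// types (ArgGroup, ValueRef) are assumed from the callback signatures used in
// this diff. Requires the DynamicDispatchNode header (path omitted here).

// Shader resolver: re-evaluated at encode time, so it can pick a shader
// variant based on the output tensor's current properties.
vkapi::ShaderInfo pick_my_op_shader(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)resize_args;
  const ValueRef out = args.at(0).refs.at(0);
  std::string kernel_name = "my_op";
  add_storage_type_suffix(kernel_name, graph->storage_type_of(out));
  add_dtype_suffix(kernel_name, graph->dtype_of(out));
  return VK_KERNEL_FROM_STR(kernel_name);
}

// Global work group resolver: sized from the output's current logical limits
// rather than its maximum sizes, which is what avoids launching inactive
// threads when the sequence length is far below the maximum.
utils::uvec3 pick_my_op_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)resize_args;
  return graph->logical_limits_of(args.at(0).refs.at(0));
}

// The node stores the callbacks instead of a fixed shader and workgroup sizes.
void add_my_op_node(ComputeGraph& graph, const ValueRef in, const ValueRef out) {
  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      pick_my_op_shader,
      pick_my_op_global_wg_size,
      default_pick_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Shader params buffers
      {graph.logical_limits_ubo(out)},
      // Push Constants
      {},
      // Specialization Constants
      {},
      // Resize Args
      {},
      // Resizing Logic
      nullptr));
}
```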