From e4fd85db17cdaeb695c9f68344d86901cd760aea Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 19 May 2025 11:55:23 -0700 Subject: [PATCH] [ET-VK] Introduce `DynamicDispatchNode` ## Context The `DynamicDispatchNode` class in introduced in this diff to allow for shader re-selection upon input resize. See the previous diff in the stack for more context on why this functionality is needed. Differential Revision: [D75013780](https://our.internmc.facebook.com/intern/diff/D75013780/) [ghstack-poisoned] --- backends/vulkan/runtime/graph/ComputeGraph.h | 1 + .../vulkan/runtime/graph/ops/DispatchNode.cpp | 4 + .../vulkan/runtime/graph/ops/DispatchNode.h | 8 +- .../runtime/graph/ops/DynamicDispatchNode.cpp | 53 +++++++ .../runtime/graph/ops/DynamicDispatchNode.h | 69 +++++++++ .../test/glsl/dynamic_dispatch_test.glsl | 45 ++++++ .../test/glsl/dynamic_dispatch_test.yaml | 7 + .../vulkan/test/vulkan_compute_api_test.cpp | 137 ++++++++++++++++++ 8 files changed, 320 insertions(+), 4 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp create mode 100644 backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h create mode 100644 backends/vulkan/test/glsl/dynamic_dispatch_test.glsl create mode 100644 backends/vulkan/test/glsl/dynamic_dispatch_test.yaml diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index d09597ad778..32763417fc0 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -21,6 +21,7 @@ #include #include +#include #include #include diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index 51ff0c122b0..166c9077bc8 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -12,6 +12,8 @@ #include +#include + namespace vkcompute { DispatchNode::DispatchNode( @@ -39,6 +41,8 @@ void DispatchNode::encode(ComputeGraph* graph) { if (!shader_) { return; } + std::cout << "dynamically dispatching... " << shader_.kernel_name + << std::endl; api::Context* const context = graph->context(); vkapi::PipelineBarrier pipeline_barrier{}; diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h index 172ab49a98a..c45f0a741fd 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h @@ -22,7 +22,7 @@ class ComputeGraph; /* * Represents a single shader execution op in a ML model. */ -class DispatchNode final : public ExecuteNode { +class DispatchNode : public ExecuteNode { friend class ComputeGraph; public: @@ -43,9 +43,9 @@ class DispatchNode final : public ExecuteNode { void encode(ComputeGraph* graph) override; protected: - const vkapi::ShaderInfo shader_; - const utils::uvec3 global_workgroup_size_; - const utils::WorkgroupSize local_workgroup_size_; + vkapi::ShaderInfo shader_; + utils::uvec3 global_workgroup_size_; + utils::WorkgroupSize local_workgroup_size_; const vkapi::ParamsBindList params_; const vkapi::SpecVarList spec_vars_; const std::vector push_constants_; diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp new file mode 100644 index 00000000000..d4a9aa28cb7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp @@ -0,0 +1,53 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include + +#include + +namespace vkcompute { + +DynamicDispatchNode::DynamicDispatchNode( + ComputeGraph& graph, + const PickShaderFn& pick_shader_fn, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn) + : DispatchNode( + graph, + pick_shader_fn(&graph, args, resize_args), + pick_global_wg_fn(&graph, args, resize_args), + pick_local_wg_fn(&graph, args, resize_args), + args, + params, + spec_vars, + resize_fn, + resize_args, + push_constants), + pick_shader_fn_(pick_shader_fn), + pick_global_wg_fn_(pick_global_wg_fn), + pick_local_wg_fn_(pick_local_wg_fn) {} + +void DynamicDispatchNode::encode(ComputeGraph* graph) { + shader_ = pick_shader_fn_(graph, args_, resize_args_); + global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_); + local_workgroup_size_ = + utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_)); + DispatchNode::encode(graph); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h new file mode 100644 index 00000000000..1c0ef9acea1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include + +#include + +namespace vkcompute { + +class ComputeGraph; + +/* + * Represents a single shader execution op in a ML model. + */ +class DynamicDispatchNode final : public DispatchNode { + friend class ComputeGraph; + + public: + using PickShaderFn = const std::function&, + const std::vector)>; + using PickGlobalFn = const std::function&, + const std::vector)>; + using PickLocalFn = const std::function&, + const std::vector)>; + + explicit DynamicDispatchNode( + ComputeGraph& graph, + const PickShaderFn& pick_shader_fn, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn = nullptr); + + ~DynamicDispatchNode() override = default; + + void encode(ComputeGraph* graph) override; + + protected: + const PickShaderFn pick_shader_fn_; + const PickGlobalFn pick_global_wg_fn_; + const PickLocalFn pick_local_wg_fn_; + + public: + operator bool() const { + return shader_; + } +}; + +} // namespace vkcompute diff --git a/backends/vulkan/test/glsl/dynamic_dispatch_test.glsl b/backends/vulkan/test/glsl/dynamic_dispatch_test.glsl new file mode 100644 index 00000000000..341da3eeacd --- /dev/null +++ b/backends/vulkan/test/glsl/dynamic_dispatch_test.glsl @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +layout(std430) buffer; + +${layout_declare_tensor(0, "w", "t_out", "float", "texture3d")} +${layout_declare_tensor(1, "r", "t_in1", "float", "texture3d")} +${layout_declare_tensor(2, "r", "t_in2", "float", "texture3d")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 in1_sizes; + ivec4 in2_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_sizes.xyz))) { + return; + } + + + vec4 out_texel = vec4(0.0); + for (int row = 0; row < in1_sizes.y; ++row) { + ivec3 in_pos = ivec3(pos.x, row, pos.z); + vec4 in1_texel = texelFetch(t_in1, in_pos, 0); + vec4 in2_texel = texelFetch(t_in2, in_pos, 0); + + out_texel += in1_texel * in2_texel; + } + + imageStore(t_out, pos, out_texel + ${OFFSET}); +} diff --git a/backends/vulkan/test/glsl/dynamic_dispatch_test.yaml b/backends/vulkan/test/glsl/dynamic_dispatch_test.yaml new file mode 100644 index 00000000000..0f0f5f51685 --- /dev/null +++ b/backends/vulkan/test/glsl/dynamic_dispatch_test.yaml @@ -0,0 +1,7 @@ +dynamic_dispatch_test: + parameter_names_with_default_values: + OFFSET: 2.25 + shader_variants: + - NAME: dynamic_dispatch_test_var1 + - NAME: dynamic_dispatch_test_var2 + OFFSET: 5.5 diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index cf42a846db5..707e7a5ca2c 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -3297,3 +3297,140 @@ TEST(VulkanComputeGraphOpsTest, test_to_copy) { test_to_copy(); } } + +vkapi::ShaderInfo pick_dynamic_dispatch_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector additional_args) { + const ValueRef mat1 = args[1].refs[0]; + + std::string kernel_name = "dynamic_dispatch_test"; + if (graph->size_at(-2, mat1) == 1) { + kernel_name += "_var1"; + } else { + kernel_name += "_var2"; + } + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 pick_dynamic_dispatch_global_wg_size( + ComputeGraph* graph, + const std::vector& args, + const std::vector additional_args) { + const ValueRef out = args[0].refs[0]; + + return graph->logical_limits_of(out); +} + +utils::uvec3 pick_dynamic_dispatch_local_wg_size( + ComputeGraph* graph, + const std::vector& args, + const std::vector additional_args) { + return {64, 1, 1}; +} + +void resize_dynamic_dispatch_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector additional_args) { + const ValueRef out = args[0].refs[0]; + const ValueRef mat1 = args[1].refs[0]; + + std::vector out_sizes = graph->sizes_of(mat1); + out_sizes.at(out_sizes.size() - 2) = 1; + + graph->get_tensor(out)->virtual_resize(out_sizes); +} + +void add_dynamic_dispatch_test_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2, + const ValueRef out) { + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_dynamic_dispatch_shader, + pick_dynamic_dispatch_global_wg_size, + pick_dynamic_dispatch_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, + // Shader params buffers + {}, + // Push Constants + {graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(mat2)}, + // Specialization constants + {}, + // Resize Logic + {}, + resize_dynamic_dispatch_node)); +} + +vkcompute::ComputeGraph build_dynamic_dispatch_test_graph(int M, int N) { + using namespace vkcompute; + GraphConfig config; + ComputeGraph graph(config); + + vkapi::ScalarType dtype = vkapi::kFloat; + utils::StorageType in_out_stype = utils::kTexture3D; + utils::GPUMemoryLayout memory_layout = utils::kWidthPacked; + + std::vector mat1_size = {M, N}; + std::vector mat2_size = {M, N}; + std::vector out_size = {1, N}; + + IOValueRef mat1 = + graph.add_input_tensor(mat1_size, dtype, in_out_stype, memory_layout); + IOValueRef mat2{}; + + mat2.value = graph.add_tensor(mat2_size, dtype, in_out_stype, memory_layout); + mat2.staging = graph.set_input_tensor(mat2.value); + + IOValueRef out; + out.value = graph.add_tensor(out_size, dtype, in_out_stype, memory_layout); + + add_dynamic_dispatch_test_node(graph, mat1, mat2, out); + + out.staging = graph.set_output_tensor(out.value); + + return graph; +} + +void test_dynamic_dispatch(int M, int N) { + ComputeGraph graph = build_dynamic_dispatch_test_graph(M, N); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + for (int i = 1; i < 4; i++) { + float val_mat1 = i; + float val_mat2 = i + 1; + // 5.3 is a hardcoded offset in the compute shader + float val_out = M * (val_mat1 * val_mat2) + 5.5; + execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out}); + } + + // Switch to GEMV mode + int new_N = N / 2; + std::vector new_mat1_size = {1, new_N}; + std::vector new_mat2_size = {1, new_N}; + graph.resize_input(0, new_mat1_size); + graph.resize_input(1, new_mat2_size); + graph.propagate_resize(); + + graph.encode_execute(); + + for (int i = 1; i < 4; i++) { + float val_mat1 = i; + float val_mat2 = i + 1; + float val_out = (val_mat1 * val_mat2) + 2.25; + execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out}); + } +} + +TEST(VulkanComputeGraphOpsTest, test_dynamic_dispatch_graph) { + test_dynamic_dispatch(128, 128); +}