diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index d09597ad778..32763417fc0 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -21,6 +21,7 @@ #include #include +#include #include #include diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.h b/backends/vulkan/runtime/graph/ops/DispatchNode.h index 172ab49a98a..c45f0a741fd 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.h @@ -22,7 +22,7 @@ class ComputeGraph; /* * Represents a single shader execution op in a ML model. */ -class DispatchNode final : public ExecuteNode { +class DispatchNode : public ExecuteNode { friend class ComputeGraph; public: @@ -43,9 +43,9 @@ class DispatchNode final : public ExecuteNode { void encode(ComputeGraph* graph) override; protected: - const vkapi::ShaderInfo shader_; - const utils::uvec3 global_workgroup_size_; - const utils::WorkgroupSize local_workgroup_size_; + vkapi::ShaderInfo shader_; + utils::uvec3 global_workgroup_size_; + utils::WorkgroupSize local_workgroup_size_; const vkapi::ParamsBindList params_; const vkapi::SpecVarList spec_vars_; const std::vector push_constants_; diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp new file mode 100644 index 00000000000..ac84916c6fa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +namespace vkcompute { + +DynamicDispatchNode::DynamicDispatchNode( + ComputeGraph& graph, + const PickShaderFn& pick_shader_fn, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn) + : DispatchNode( + graph, + pick_shader_fn(&graph, args, resize_args), + pick_global_wg_fn(&graph, args, resize_args), + pick_local_wg_fn(&graph, args, resize_args), + args, + params, + push_constants, + spec_vars, + resize_args, + resize_fn), + pick_shader_fn_(pick_shader_fn), + pick_global_wg_fn_(pick_global_wg_fn), + pick_local_wg_fn_(pick_local_wg_fn) {} + +void DynamicDispatchNode::encode(ComputeGraph* graph) { + shader_ = pick_shader_fn_(graph, args_, resize_args_); + global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_); + local_workgroup_size_ = + utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_)); + DispatchNode::encode(graph); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h new file mode 100644 index 00000000000..ede50941415 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h @@ -0,0 +1,69 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +#include +#include + +#include + +namespace vkcompute { + +class ComputeGraph; + +/* + * A shader execution op whose shader and workgroup sizes are picked on encode. 
+ */ +class DynamicDispatchNode final : public DispatchNode { + friend class ComputeGraph; + + public: + using PickShaderFn = const std::function&, + const std::vector&)>; + using PickGlobalFn = const std::function&, + const std::vector&)>; + using PickLocalFn = const std::function&, + const std::vector&)>; + + explicit DynamicDispatchNode( + ComputeGraph& graph, + const PickShaderFn& pick_shader_fn, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn = nullptr); + + ~DynamicDispatchNode() override = default; + + void encode(ComputeGraph* graph) override; + + protected: + const PickShaderFn pick_shader_fn_; + const PickGlobalFn pick_global_wg_fn_; + const PickLocalFn pick_local_wg_fn_; + + public: + operator bool() const { + return shader_; + } +}; + +} // namespace vkcompute diff --git a/backends/vulkan/test/glsl/dynamic_dispatch_test.glsl b/backends/vulkan/test/glsl/dynamic_dispatch_test.glsl new file mode 100644 index 00000000000..341da3eeacd --- /dev/null +++ b/backends/vulkan/test/glsl/dynamic_dispatch_test.glsl @@ -0,0 +1,45 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +layout(std430) buffer; + +${layout_declare_tensor(0, "w", "t_out", "float", "texture3d")} +${layout_declare_tensor(1, "r", "t_in1", "float", "texture3d")} +${layout_declare_tensor(2, "r", "t_in2", "float", "texture3d")} + +layout(push_constant) uniform restrict Block { + ivec4 out_sizes; + ivec4 in1_sizes; + ivec4 in2_sizes; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const ivec3 pos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(pos, out_sizes.xyz))) { + return; + } + + + vec4 out_texel = vec4(0.0); + for (int row = 0; row < in1_sizes.y; ++row) { + ivec3 in_pos = ivec3(pos.x, row, pos.z); + vec4 in1_texel = texelFetch(t_in1, in_pos, 0); + vec4 in2_texel = texelFetch(t_in2, in_pos, 0); + + out_texel += in1_texel * in2_texel; + } + + imageStore(t_out, pos, out_texel + ${OFFSET}); +} diff --git a/backends/vulkan/test/glsl/dynamic_dispatch_test.yaml b/backends/vulkan/test/glsl/dynamic_dispatch_test.yaml new file mode 100644 index 00000000000..0f0f5f51685 --- /dev/null +++ b/backends/vulkan/test/glsl/dynamic_dispatch_test.yaml @@ -0,0 +1,7 @@ +dynamic_dispatch_test: + parameter_names_with_default_values: + OFFSET: 2.25 + shader_variants: + - NAME: dynamic_dispatch_test_var1 + - NAME: dynamic_dispatch_test_var2 + OFFSET: 5.5 diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index cf42a846db5..a6475d95d07 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -3297,3 +3297,140 @@ TEST(VulkanComputeGraphOpsTest, test_to_copy) { test_to_copy(); } } + +vkapi::ShaderInfo pick_dynamic_dispatch_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& additional_args) { + const ValueRef mat1 = args[1].refs[0]; + + std::string kernel_name = "dynamic_dispatch_test"; + if (graph->size_at(-2, mat1) == 1) 
{ + kernel_name += "_var1"; + } else { + kernel_name += "_var2"; + } + return VK_KERNEL_FROM_STR(kernel_name); +} + +utils::uvec3 pick_dynamic_dispatch_global_wg_size( + ComputeGraph* graph, + const std::vector& args, + const std::vector& additional_args) { + const ValueRef out = args[0].refs[0]; + + return graph->logical_limits_of(out); +} + +utils::uvec3 pick_dynamic_dispatch_local_wg_size( + ComputeGraph* graph, + const std::vector& args, + const std::vector& additional_args) { + return {64, 1, 1}; +} + +void resize_dynamic_dispatch_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& additional_args) { + const ValueRef out = args[0].refs[0]; + const ValueRef mat1 = args[1].refs[0]; + + std::vector out_sizes = graph->sizes_of(mat1); + out_sizes.at(out_sizes.size() - 2) = 1; + + graph->get_tensor(out)->virtual_resize(out_sizes); +} + +void add_dynamic_dispatch_test_node( + ComputeGraph& graph, + const ValueRef mat1, + const ValueRef mat2, + const ValueRef out) { + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + pick_dynamic_dispatch_shader, + pick_dynamic_dispatch_global_wg_size, + pick_dynamic_dispatch_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}}, + // Shader params buffers + {}, + // Push Constants + {graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(mat2)}, + // Specialization constants + {}, + // Resize Logic + {}, + resize_dynamic_dispatch_node)); +} + +vkcompute::ComputeGraph build_dynamic_dispatch_test_graph(int M, int N) { + using namespace vkcompute; + GraphConfig config; + ComputeGraph graph(config); + + vkapi::ScalarType dtype = vkapi::kFloat; + utils::StorageType in_out_stype = utils::kTexture3D; + utils::GPUMemoryLayout memory_layout = utils::kWidthPacked; + + std::vector mat1_size = {M, N}; + std::vector mat2_size = {M, N}; + std::vector out_size = {1, N}; + + IOValueRef mat1 = + graph.add_input_tensor(mat1_size, dtype, in_out_stype, 
memory_layout); + IOValueRef mat2{}; + + mat2.value = graph.add_tensor(mat2_size, dtype, in_out_stype, memory_layout); + mat2.staging = graph.set_input_tensor(mat2.value); + + IOValueRef out; + out.value = graph.add_tensor(out_size, dtype, in_out_stype, memory_layout); + + add_dynamic_dispatch_test_node(graph, mat1, mat2, out); + + out.staging = graph.set_output_tensor(out.value); + + return graph; +} + +void test_dynamic_dispatch(int M, int N) { + ComputeGraph graph = build_dynamic_dispatch_test_graph(M, N); + + graph.prepare(); + graph.encode_prepack(); + graph.prepack(); + graph.encode_execute(); + + for (int i = 1; i < 4; i++) { + float val_mat1 = i; + float val_mat2 = i + 1; + // 5.5 is the OFFSET of the var2 shader variant, which is picked when M != 1 + float val_out = M * (val_mat1 * val_mat2) + 5.5; + execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out}); + } + + // Switch to GEMV mode (dim -2 becomes 1, selecting the var1 shader) + int new_N = N / 2; + std::vector new_mat1_size = {1, new_N}; + std::vector new_mat2_size = {1, new_N}; + graph.resize_input(0, new_mat1_size); + graph.resize_input(1, new_mat2_size); + graph.propagate_resize(); + + graph.encode_execute(); + + for (int i = 1; i < 4; i++) { + float val_mat1 = i; + float val_mat2 = i + 1; + float val_out = (val_mat1 * val_mat2) + 2.25; + execute_graph_and_check_output(graph, {val_mat1, val_mat2}, {val_out}); + } +} + +TEST(VulkanComputeGraphOpsTest, test_dynamic_dispatch_graph) { + test_dynamic_dispatch(128, 128); +}