diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp index ac84916c6fa..a8d2fe2e99d 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp @@ -25,9 +25,9 @@ DynamicDispatchNode::DynamicDispatchNode( const ResizeFunction& resize_fn) : DispatchNode( graph, - pick_shader_fn(&graph, args, resize_args), - pick_global_wg_fn(&graph, args, resize_args), - pick_local_wg_fn(&graph, args, resize_args), + vkapi::ShaderInfo(), + {1u, 1u, 1u}, + {1u, 1u, 1u}, args, params, push_constants, @@ -36,13 +36,57 @@ DynamicDispatchNode::DynamicDispatchNode( resize_fn), pick_shader_fn_(pick_shader_fn), pick_global_wg_fn_(pick_global_wg_fn), + pick_local_wg_fn_(pick_local_wg_fn) { + shader_ = pick_shader_fn(&graph, args, resize_args); + global_workgroup_size_ = + pick_global_wg_fn(&graph, shader_, args, resize_args); + local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn( + &graph, shader_, global_workgroup_size_, args, resize_args)); +} + +DynamicDispatchNode::DynamicDispatchNode( + ComputeGraph& graph, + const vkapi::ShaderInfo& shader, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn) + : DispatchNode( + graph, + shader, + pick_global_wg_fn(&graph, shader, args, resize_args), + pick_local_wg_fn( + &graph, + shader, + pick_global_wg_fn(&graph, shader, args, resize_args), + args, + resize_args), + args, + params, + push_constants, + spec_vars, + resize_args, + resize_fn), + pick_shader_fn_{nullptr}, + pick_global_wg_fn_(pick_global_wg_fn), pick_local_wg_fn_(pick_local_wg_fn) {} void DynamicDispatchNode::encode(ComputeGraph* graph) { - shader_ = pick_shader_fn_(graph, args_, resize_args_); - global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_); - local_workgroup_size_ = - utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_)); + if (pick_shader_fn_) { + shader_ = pick_shader_fn_(graph, args_, resize_args_); + } + if (pick_global_wg_fn_) { + global_workgroup_size_ = + pick_global_wg_fn_(graph, shader_, args_, resize_args_); + } + if (pick_local_wg_fn_) { + local_workgroup_size_ = utils::WorkgroupSize(pick_local_wg_fn_( + graph, shader_, global_workgroup_size_, args_, resize_args_)); + } DispatchNode::encode(graph); } diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h index ede50941415..005151272c3 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h @@ -32,10 +32,13 @@ class DynamicDispatchNode final : public DispatchNode { const std::vector&)>; using PickGlobalFn = const std::function&, const std::vector&)>; using PickLocalFn = const std::function&, const std::vector&)>; @@ -51,6 +54,18 @@ class DynamicDispatchNode final : public DispatchNode { const std::vector& resize_args, const ResizeFunction& resize_fn = nullptr); + explicit DynamicDispatchNode( + ComputeGraph& graph, + const vkapi::ShaderInfo& shader, + const PickGlobalFn& pick_global_wg_fn, + const PickLocalFn& pick_local_wg_fn, + const std::vector& args, + const vkapi::ParamsBindList& params, + const std::vector& push_constants, + const vkapi::SpecVarList& spec_vars, + const std::vector& resize_args, + const ResizeFunction& resize_fn = nullptr); + ~DynamicDispatchNode() override = default; void encode(ComputeGraph* graph) override; diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index f014cc79f56..60dfb3b8606 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -3314,17 +3315,23 @@ vkapi::ShaderInfo pick_dynamic_dispatch_shader( utils::uvec3 pick_dynamic_dispatch_global_wg_size( ComputeGraph* graph, + const vkapi::ShaderInfo& shader, const std::vector& args, - const std::vector& additional_args) { + const std::vector& resize_args) { + (void)shader; const ValueRef out = args[0].refs[0]; - return graph->logical_limits_of(out); } utils::uvec3 pick_dynamic_dispatch_local_wg_size( ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, const std::vector& args, - const std::vector& additional_args) { + const std::vector& resize_args) { + (void)graph; + (void)shader; + (void)global_workgroup_size; return {64, 1, 1}; }