From b57607158be68fdf9750513f328c50e3ef57f806 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Fri, 30 May 2025 08:26:02 -0700 Subject: [PATCH] [ET-VK][ez] Updates to DynamicDispatchNode ## Changes * Pass in global work group size to the local work group size determination function in `DynamicDIspatchNode` ## Motivation Oftentimes it is useful to know what the global work group size is when determining what the local group group size should be. ## Performance Impact None. Differential Revision: [D75686047](https://our.internmc.facebook.com/intern/diff/D75686047/) [ghstack-poisoned] --- .../vulkan/runtime/graph/ops/DynamicDispatchNode.cpp | 10 +++++++--- .../vulkan/runtime/graph/ops/DynamicDispatchNode.h | 1 + backends/vulkan/test/vulkan_compute_api_test.cpp | 1 + 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp index ac84916c6fa..1b313f064d6 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp @@ -27,7 +27,11 @@ DynamicDispatchNode::DynamicDispatchNode( graph, pick_shader_fn(&graph, args, resize_args), pick_global_wg_fn(&graph, args, resize_args), - pick_local_wg_fn(&graph, args, resize_args), + pick_local_wg_fn( + &graph, + pick_global_wg_fn(&graph, args, resize_args), + args, + resize_args), args, params, push_constants, @@ -41,8 +45,8 @@ DynamicDispatchNode::DynamicDispatchNode( void DynamicDispatchNode::encode(ComputeGraph* graph) { shader_ = pick_shader_fn_(graph, args_, resize_args_); global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_); - local_workgroup_size_ = - utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_)); + local_workgroup_size_ = utils::WorkgroupSize( + pick_local_wg_fn_(graph, global_workgroup_size_, args_, resize_args_)); DispatchNode::encode(graph); } diff --git a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h index ede50941415..6fd2441a94f 100644 --- a/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h +++ b/backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h @@ -36,6 +36,7 @@ class DynamicDispatchNode final : public DispatchNode { const std::vector&)>; using PickLocalFn = const std::function&, const std::vector&)>; diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp index a6475d95d07..af60680e6fe 100644 --- a/backends/vulkan/test/vulkan_compute_api_test.cpp +++ b/backends/vulkan/test/vulkan_compute_api_test.cpp @@ -3324,6 +3324,7 @@ utils::uvec3 pick_dynamic_dispatch_global_wg_size( utils::uvec3 pick_dynamic_dispatch_local_wg_size( ComputeGraph* graph, + const utils::uvec3& global_workgroup_size, const std::vector& args, const std::vector& additional_args) { return {64, 1, 1};