Skip to content

Commit 17987a3

Browse files
committed
[ET-VK][ez] Updates to DynamicDispatchNode
Pull Request resolved: #11254 ## Changes * Pass in global work group size to the local work group size determination function in `DynamicDIspatchNode` ## Motivation Oftentimes it is useful to know what the global work group size is when determining what the local group group size should be. ## Performance Impact None. ghstack-source-id: 287234293 @exported-using-ghexport Differential Revision: [D75686047](https://our.internmc.facebook.com/intern/diff/D75686047/)
1 parent 3215a47 commit 17987a3

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ DynamicDispatchNode::DynamicDispatchNode(
2727
graph,
2828
pick_shader_fn(&graph, args, resize_args),
2929
pick_global_wg_fn(&graph, args, resize_args),
30-
pick_local_wg_fn(&graph, args, resize_args),
30+
pick_local_wg_fn(
31+
&graph,
32+
pick_global_wg_fn(&graph, args, resize_args),
33+
args,
34+
resize_args),
3135
args,
3236
params,
3337
push_constants,
@@ -41,8 +45,8 @@ DynamicDispatchNode::DynamicDispatchNode(
4145
void DynamicDispatchNode::encode(ComputeGraph* graph) {
4246
shader_ = pick_shader_fn_(graph, args_, resize_args_);
4347
global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_);
44-
local_workgroup_size_ =
45-
utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_));
48+
local_workgroup_size_ = utils::WorkgroupSize(
49+
pick_local_wg_fn_(graph, global_workgroup_size_, args_, resize_args_));
4650
DispatchNode::encode(graph);
4751
}
4852

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class DynamicDispatchNode final : public DispatchNode {
3636
const std::vector<ValueRef>&)>;
3737
using PickLocalFn = const std::function<utils::uvec3(
3838
ComputeGraph*,
39+
const utils::uvec3& global_workgroup_size,
3940
const std::vector<ArgGroup>&,
4041
const std::vector<ValueRef>&)>;
4142

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3323,6 +3323,7 @@ utils::uvec3 pick_dynamic_dispatch_global_wg_size(
33233323

33243324
utils::uvec3 pick_dynamic_dispatch_local_wg_size(
33253325
ComputeGraph* graph,
3326+
const utils::uvec3& global_workgroup_size,
33263327
const std::vector<ArgGroup>& args,
33273328
const std::vector<ValueRef>& additional_args) {
33283329
return {64, 1, 1};

0 commit comments

Comments
 (0)