Skip to content

Commit b576071

Browse files
committed
[ET-VK][ez] Updates to DynamicDispatchNode
## Changes * Pass in global work group size to the local work group size determination function in `DynamicDIspatchNode` ## Motivation Oftentimes it is useful to know what the global work group size is when determining what the local group group size should be. ## Performance Impact None. Differential Revision: [D75686047](https://our.internmc.facebook.com/intern/diff/D75686047/) [ghstack-poisoned]
1 parent 85726c4 commit b576071

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,11 @@ DynamicDispatchNode::DynamicDispatchNode(
2727
graph,
2828
pick_shader_fn(&graph, args, resize_args),
2929
pick_global_wg_fn(&graph, args, resize_args),
30-
pick_local_wg_fn(&graph, args, resize_args),
30+
pick_local_wg_fn(
31+
&graph,
32+
pick_global_wg_fn(&graph, args, resize_args),
33+
args,
34+
resize_args),
3135
args,
3236
params,
3337
push_constants,
@@ -41,8 +45,8 @@ DynamicDispatchNode::DynamicDispatchNode(
4145
void DynamicDispatchNode::encode(ComputeGraph* graph) {
4246
shader_ = pick_shader_fn_(graph, args_, resize_args_);
4347
global_workgroup_size_ = pick_global_wg_fn_(graph, args_, resize_args_);
44-
local_workgroup_size_ =
45-
utils::WorkgroupSize(pick_local_wg_fn_(graph, args_, resize_args_));
48+
local_workgroup_size_ = utils::WorkgroupSize(
49+
pick_local_wg_fn_(graph, global_workgroup_size_, args_, resize_args_));
4650
DispatchNode::encode(graph);
4751
}
4852

backends/vulkan/runtime/graph/ops/DynamicDispatchNode.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class DynamicDispatchNode final : public DispatchNode {
3636
const std::vector<ValueRef>&)>;
3737
using PickLocalFn = const std::function<utils::uvec3(
3838
ComputeGraph*,
39+
const utils::uvec3& global_workgroup_size,
3940
const std::vector<ArgGroup>&,
4041
const std::vector<ValueRef>&)>;
4142

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3324,6 +3324,7 @@ utils::uvec3 pick_dynamic_dispatch_global_wg_size(
33243324

33253325
utils::uvec3 pick_dynamic_dispatch_local_wg_size(
33263326
ComputeGraph* graph,
3327+
const utils::uvec3& global_workgroup_size,
33273328
const std::vector<ArgGroup>& args,
33283329
const std::vector<ValueRef>& additional_args) {
33293330
return {64, 1, 1};

0 commit comments

Comments
 (0)