Skip to content

Commit b449ac4

Browse files
committed
[ET-VK] Replacing use of adaptive_work_group_size function by create_local_wg_size function.
This diff replaces the use of the adaptive_work_group_size function with create_local_wg_size function, which is better tuned for improving shader performance. Differential Revision: [D66308779](https://our.internmc.facebook.com/intern/diff/D66308779/) [ghstack-poisoned]
1 parent 700a473 commit b449ac4

File tree

7 files changed

+8
-8
lines changed

7 files changed

+8
-8
lines changed

backends/vulkan/runtime/graph/ops/impl/Copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ void add_copy_channel_offset_node(
135135
utils::safe_downcast<uint32_t>(dim_at<kWidth4D>(in_sizes)),
136136
utils::safe_downcast<uint32_t>(dim_at<kHeight4D>(in_sizes)),
137137
utils::safe_downcast<uint32_t>(dst_last_z - dst_first_z + 1)};
138-
uvec3 local_size = adaptive_work_group_size(global_size);
138+
uvec3 local_size = graph.create_local_wg_size(global_size);
139139

140140
const struct Block final {
141141
ivec3 range;

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ void add_addmm_optimized_node(
198198
} else {
199199
global_size = utils::divup_vec(global_size, {4, 4, 1});
200200
}
201-
utils::uvec3 local_size = adaptive_work_group_size(global_size);
201+
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
202202

203203
graph.execute_nodes().emplace_back(new DispatchNode(
204204
graph,

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ void add_matmul_optimized_node(
213213
global_size = utils::divup_vec(global_size, {4, 4, 1});
214214
}
215215

216-
utils::uvec3 local_size = adaptive_work_group_size(global_size);
216+
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
217217

218218
graph.execute_nodes().emplace_back(new DispatchNode(
219219
graph,

backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ void add_native_layer_norm_node(
8989
std::vector<int64_t> in_sizes = t_input->sizes();
9090

9191
utils::uvec3 global_size = t_mean->logical_limits();
92-
utils::uvec3 local_size = adaptive_work_group_size(global_size);
92+
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
9393

9494
std::string kernel_name("native_layer_norm");
9595
kernel_name.reserve(kShaderNameReserve);

backends/vulkan/runtime/graph/ops/impl/Pool.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ void add_max_pool2d_node(
7979
check_pool2d_args(*t_in, *t_out);
8080

8181
utils::uvec3 global_size = t_out->logical_limits();
82-
utils::uvec3 local_size = adaptive_work_group_size(global_size);
82+
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
8383

8484
std::string kernel_name("max_pool2d");
8585
add_dtype_suffix(kernel_name, *t_out);
@@ -154,7 +154,7 @@ void add_avg_pool2d_node(
154154
check_pool2d_args(*t_in, *t_out);
155155

156156
utils::uvec3 global_size = t_out->logical_limits();
157-
utils::uvec3 local_size = adaptive_work_group_size(global_size);
157+
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
158158

159159
std::string kernel_name("avg_pool2d");
160160
add_dtype_suffix(kernel_name, *t_out);

backends/vulkan/runtime/graph/ops/impl/Repeat.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ void add_repeat_channel_node(
9090
// Channel packed global work ids
9191
running_range[2] = out_whcn_sizes[3] * utils::div_up_4(out_whcn_sizes[2]);
9292
utils::uvec3 global_size = utils::make_uvec3(running_range);
93-
utils::uvec3 local_size = adaptive_work_group_size(global_size);
93+
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
9494

9595
const struct Block final {
9696
utils::ivec4 out_sizes;

backends/vulkan/runtime/graph/ops/impl/Slice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ void add_slice_tensor_copy_node(
126126
add_dtype_suffix(kernel_name, *t_out);
127127

128128
utils::uvec3 global_size = t_out->logical_limits();
129-
utils::uvec3 local_size = adaptive_work_group_size(global_size);
129+
utils::uvec3 local_size = graph.create_local_wg_size(global_size);
130130

131131
const struct Block final {
132132
int dim;

0 commit comments

Comments
 (0)