Skip to content

Commit 59f7b82

Browse files
trivedivivekfacebook-github-bot
authored andcommitted
Calculating axis's local wg size based on global workload and making it as close as possible to warp size of 32.
Summary: This diff changes the local workgroup size calculation logic in the Vulkan backend of Executorch. The workgroup size of the largest axis is kept largest so workgroups are better occupied. The workgroup size is calculated based on the warp size of 32. When kernel is 2 dimensional largest axis is kept close to warp size it, so threads in the same warp Read / Write to consecutive memory locations, thus improving performance. Reviewed By: SS-JIA Differential Revision: D64418632
1 parent c7474f8 commit 59f7b82

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -485,24 +485,46 @@ utils::uvec3 ComputeGraph::create_local_wg_size(
485485
return config_.local_wg_size_override;
486486
}
487487

488-
utils::uvec3 local_group_size = {4, 4, 4};
488+
// array containing axis index and global workgroup size
489+
std::pair<uint32_t,uint32_t> global_wg_size_desc[] = {
490+
{0u, global_wg_size[0]},
491+
{1u, global_wg_size[1]},
492+
{2u, global_wg_size[2]}};
493+
494+
// sort the global workgroup size in descending order
495+
if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
496+
std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
497+
}
498+
if (global_wg_size_desc[1].second < global_wg_size_desc[2].second) {
499+
std::swap(global_wg_size_desc[1], global_wg_size_desc[2]);
500+
}
501+
if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
502+
std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
503+
}
489504

490-
if (global_wg_size[2u] == 1) {
491-
if (global_wg_size[1u] == 1) {
505+
utils::uvec3 local_group_size = {
506+
8,
507+
std::max(1u, std::min(4u, global_wg_size_desc[1].second)),
508+
std::max(1u, std::min(2u, global_wg_size_desc[2].second))
509+
};
510+
511+
if (global_wg_size_desc[2u].second == 1) {
512+
if (global_wg_size_desc[1u].second == 1) {
492513
local_group_size[0u] = 64;
493514
local_group_size[1u] = 1;
494-
local_group_size[2u] = 1;
495-
} else if (global_wg_size[1u] < 8) {
515+
} else if (global_wg_size_desc[1u].second % 4 == 0) {
496516
local_group_size[0u] = 16;
497517
local_group_size[1u] = 4;
498-
local_group_size[2u] = 1;
499518
} else {
500-
local_group_size[0u] = 8;
501-
local_group_size[1u] = 8;
502-
local_group_size[2u] = 1;
519+
local_group_size[0u] = 32;
520+
local_group_size[1u] = 2;
503521
}
504522
}
505-
return local_group_size;
523+
524+
return {
525+
local_group_size[global_wg_size_desc[0].first],
526+
local_group_size[global_wg_size_desc[1].first],
527+
local_group_size[global_wg_size_desc[2].first]};
506528
}
507529

508530
utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {

0 commit comments

Comments
 (0)