Skip to content

Commit eaf3bf9

Browse files
trivedivivekfacebook-github-bot
authored andcommitted
Calculating axis's local wg size based on global workload and making it as close as possible to warp size of 32. (#6409)
Summary: This diff changes the local workgroup size calculation logic in the Vulkan backend of Executorch. The workgroup size of the largest axis is kept largest so workgroups are better occupied. The workgroup size is calculated based on the warp size of 32. When kernel is 2 dimensional largest axis is kept close to warp size it, so threads in the same warp Read / Write to consecutive memory locations, thus improving performance. Reviewed By: SS-JIA Differential Revision: D64418632
1 parent 8c96805 commit eaf3bf9

File tree

1 file changed

+31
-10
lines changed

1 file changed

+31
-10
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -485,24 +485,45 @@ utils::uvec3 ComputeGraph::create_local_wg_size(
485485
return config_.local_wg_size_override;
486486
}
487487

488-
utils::uvec3 local_group_size = {4, 4, 4};
488+
// array containing axis index and global workgroup size
489+
std::pair<uint32_t, uint32_t> global_wg_size_desc[] = {
490+
{0u, global_wg_size[0]},
491+
{1u, global_wg_size[1]},
492+
{2u, global_wg_size[2]}};
493+
494+
// sort the global workgroup size in descending order
495+
if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
496+
std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
497+
}
498+
if (global_wg_size_desc[1].second < global_wg_size_desc[2].second) {
499+
std::swap(global_wg_size_desc[1], global_wg_size_desc[2]);
500+
}
501+
if (global_wg_size_desc[0].second < global_wg_size_desc[1].second) {
502+
std::swap(global_wg_size_desc[0], global_wg_size_desc[1]);
503+
}
489504

490-
if (global_wg_size[2u] == 1) {
491-
if (global_wg_size[1u] == 1) {
505+
utils::uvec3 local_group_size = {
506+
8,
507+
std::max(1u, std::min(4u, global_wg_size_desc[1].second)),
508+
std::max(1u, std::min(2u, global_wg_size_desc[2].second))};
509+
510+
if (global_wg_size_desc[2u].second == 1) {
511+
if (global_wg_size_desc[1u].second == 1) {
492512
local_group_size[0u] = 64;
493513
local_group_size[1u] = 1;
494-
local_group_size[2u] = 1;
495-
} else if (global_wg_size[1u] < 8) {
514+
} else if (global_wg_size_desc[1u].second % 4 == 0) {
496515
local_group_size[0u] = 16;
497516
local_group_size[1u] = 4;
498-
local_group_size[2u] = 1;
499517
} else {
500-
local_group_size[0u] = 8;
501-
local_group_size[1u] = 8;
502-
local_group_size[2u] = 1;
518+
local_group_size[0u] = 32;
519+
local_group_size[1u] = 2;
503520
}
504521
}
505-
return local_group_size;
522+
523+
return {
524+
local_group_size[global_wg_size_desc[0].first],
525+
local_group_size[global_wg_size_desc[1].first],
526+
local_group_size[global_wg_size_desc[2].first]};
506527
}
507528

508529
utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {

0 commit comments

Comments
 (0)