Commit f95a3f7

[ET-VK] Better work group sizes for matmul (#13378)
## Context

Currently `default_pick_local_wg_size()` (which internally calls `ComputeGraph::create_local_wg_size`) is used to select the local work group size for matrix multiplication ops. However, these functions bias the size of the local work group towards the largest dim of the global work group, producing local wg sizes like

```
shader                                                  globalwg size   localwg size
======================================================  ==============  ============  ====
linear_qga4w_tiled_texture3d_texture3d_texture2d_float  {256, 29, 1}    {32, 2, 1}    1487
matmul_naive_texture3d_float                            {29, 115, 32}   {4, 2, 8}      712
```

for matrix multiplication shaders. This behaviour was introduced in D64418632 / #6409.

However, experimental testing shows that a "square" work group size of `{8, 8, 1}` works a lot better for matrix multiplication shaders. The theoretical analysis for this behaviour is that the local work group size determines the memory locations that need to be loaded to compute the work group's output. For a work group with size `{W, H, 1}`, the data required to compute the output would be `W * OUTPUT_TILE_W` columns of the weight tensor and `H * OUTPUT_TILE_H` rows of the input tensor. Note that all work group items with the same W index will be requesting the same columns from the weight tensor, and all work group items with the same H index will be requesting the same rows from the input tensor.

If `H == W`, then that "balances" the amount of data that needs to be loaded from each input tensor and may result in better data sharing behaviour among all work group items.

Assuming `OUTPUT_TILE_W == OUTPUT_TILE_H == 1`, a local work group of size `{64, 1, 1}` would require 1 unique row from the input tensor and 64 unique columns from the weight tensor, resulting in `(1 + 64) * K = 65K` elements to be loaded in total, where K is the size of the shared reduction dim. Conversely, a local work group of size `{8, 8, 1}` would require 8 unique rows and 8 unique columns, resulting in only `(8 + 8) * K = 16K` unique elements to be loaded.

This highlights the need to use dedicated logic to compute work group sizes for matrix multiplication shaders.

## Changes

* Introduce `pick_hw_square_wg_size`
* Use the new local work group size determination function for Quantized Linear, Matmul, and Linear

Differential Revision: [D79813236](https://our.internmc.facebook.com/intern/diff/D79813236/)
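The load-count argument in the Context section can be checked with a few lines of C++. This is only a sketch of the arithmetic: the function name `unique_elements_loaded` and its parameters are hypothetical and do not appear in the commit, and the model assumes every invocation in the work group reads full rows and columns of length K.

```cpp
#include <cstdint>

// Unique elements a {W, H, 1} work group must read under the model in the
// commit message: W * tile_w weight columns plus H * tile_h input rows,
// each K elements long. All names here are illustrative, not ET-VK APIs.
constexpr std::uint64_t unique_elements_loaded(
    std::uint64_t wg_w,
    std::uint64_t wg_h,
    std::uint64_t k,
    std::uint64_t tile_w = 1,
    std::uint64_t tile_h = 1) {
  return (wg_w * tile_w + wg_h * tile_h) * k;
}

// Per unit of K: {64, 1, 1} reads 65 elements, {8, 8, 1} reads 16.
static_assert(unique_elements_loaded(64, 1, 1) == 65, "65K for {64, 1, 1}");
static_assert(unique_elements_loaded(8, 8, 1) == 16, "16K for {8, 8, 1}");
```

Both shapes launch 64 invocations, so under this model the difference comes from the shape of the work group alone.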
1 parent 75b77a6 commit f95a3f7

5 files changed: +48 -6 lines changed


backends/vulkan/runtime/graph/ops/impl/Common.cpp

Lines changed: 23 additions & 0 deletions
@@ -33,4 +33,27 @@ utils::uvec3 default_pick_local_wg_size(
   return graph->create_local_wg_size(global_workgroup_size);
 }
 
+utils::uvec3 pick_hw_square_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)graph;
+  (void)shader;
+  (void)args;
+  (void)resize_args;
+  // Some inactive invocations are okay; set 6 as the threshold to use
+  // a square wg size.
+  if (global_workgroup_size[0u] >= 6 && global_workgroup_size[1u] >= 6) {
+    return {8u, 8u, 1u};
+  }
+  // If width dim is sufficiently small, then bias towards height dim to reduce
+  // the number of inactive invocations.
+  if (global_workgroup_size[0u] < 6u) {
+    return {4u, 16u, 1u};
+  }
+  return {16u, 4u, 1u};
+}
+
 } // namespace vkcompute
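The selection logic added to `Common.cpp` can be exercised outside the ExecuTorch tree with a standalone sketch. Here `uvec3` is a `std::array` stand-in for `utils::uvec3`, and `pick_hw_square_wg_size_sketch` and `same` are hypothetical names; the unused `graph`, `shader`, `args`, and `resize_args` parameters are dropped for brevity.

```cpp
#include <array>
#include <cstdint>

// Stand-in for utils::uvec3; not the real ET-VK type.
using uvec3 = std::array<std::uint32_t, 3>;

// Mirrors the thresholds in the diff: square when both width and height are
// at least 6, otherwise bias towards whichever of the two dims is larger.
constexpr uvec3 pick_hw_square_wg_size_sketch(const uvec3& global_wg) {
  if (global_wg[0] >= 6u && global_wg[1] >= 6u) {
    return {8u, 8u, 1u};
  }
  if (global_wg[0] < 6u) {
    return {4u, 16u, 1u};
  }
  return {16u, 4u, 1u};
}

// Element-wise comparison usable in constant expressions before C++20.
constexpr bool same(const uvec3& a, const uvec3& b) {
  return a[0] == b[0] && a[1] == b[1] && a[2] == b[2];
}

// The two global sizes quoted in the commit message both get {8, 8, 1}.
static_assert(same(pick_hw_square_wg_size_sketch({256u, 29u, 1u}), {8u, 8u, 1u}), "");
static_assert(same(pick_hw_square_wg_size_sketch({29u, 115u, 32u}), {8u, 8u, 1u}), "");
```

Note that every returned shape has 64 invocations, and that the depth dim of the global size never influences the result.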

backends/vulkan/runtime/graph/ops/impl/Common.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,22 @@ utils::uvec3 default_pick_local_wg_size(
     const std::vector<ArgGroup>& args,
     const std::vector<ValueRef>& resize_args);
 
+/**
+ * Constructs a local work group size with the shape {W, H, 1}. The function
+ * will try to set W == H == sqrt(num_invocations), where num_invocations is
+ * typically 64. This configuration is good for ops like matrix multiplication
+ * as it reduces the total volume of unique data that the entire work group
+ * will need to read from input tensors in order to produce the output data.
+ * To compute an output tile of {W, H, 1}, the work group will need to read
+ * H unique rows = H * K unique elements from the input tensor and W unique cols
+ * = W * K elements from the weight tensor, resulting in (W + H) * K unique
+ * elements in total.
+ */
+utils::uvec3 pick_hw_square_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args);
+
 } // namespace vkcompute
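The doc comment's claim that `W == H == sqrt(num_invocations)` minimizes `(W + H) * K` can be illustrated by brute force over all integer factor pairs. The helper `min_perimeter_shape` below is hypothetical and not part of the commit; it assumes `num_invocations >= 1` and only considers shapes with `W * H == num_invocations`.

```cpp
#include <cstdint>
#include <utility>

// Among all {W, H} with W * H == num_invocations, return the pair with the
// smallest W + H (ties keep the first pair found, i.e. the smaller W).
// Since K is a common factor, minimizing W + H minimizes (W + H) * K.
constexpr std::pair<std::uint32_t, std::uint32_t> min_perimeter_shape(
    std::uint32_t num_invocations) {
  std::uint32_t best_w = num_invocations;
  std::uint32_t best_h = 1;
  for (std::uint32_t w = 1; w <= num_invocations; ++w) {
    if (num_invocations % w != 0) {
      continue; // not an exact factorization
    }
    const std::uint32_t h = num_invocations / w;
    if (w + h < best_w + best_h) {
      best_w = w;
      best_h = h;
    }
  }
  return {best_w, best_h};
}

// For the typical 64 invocations, the minimizer is the square 8 x 8 shape.
static_assert(min_perimeter_shape(64).first == 8, "W == 8");
static_assert(min_perimeter_shape(64).second == 8, "H == 8");
```

When the invocation count is not a perfect square, such as 32, the search settles on the closest-to-square factorization instead.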

backends/vulkan/runtime/graph/ops/impl/Linear.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ void add_addmm_naive_texture_node(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
       addmm_naive_texture_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}},
       // Shader params buffers
@@ -245,7 +245,7 @@ void add_addmm_naive_buffer_node(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
       addmm_naive_buffer_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2, self}, vkapi::kRead}},
       // Shader params buffers

backends/vulkan/runtime/graph/ops/impl/MatMul.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ void add_matmul_naive_buffer_node(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
       matmul_naive_buffer_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
       // Shader params buffers
@@ -158,7 +158,7 @@ void add_matmul_naive_texture3d_node(
       graph,
       pick_matmul_naive_texture3d_shader,
       default_pick_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1, mat2}, vkapi::kRead}},
       // Shader params buffers
@@ -273,7 +273,7 @@ void add_matmul_optimized_node(
       graph,
       pick_matmul_optimized_shader,
       matmul_optimized_global_wg_size,
-      default_pick_local_wg_size,
+      pick_hw_square_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {{mat1_W_packed, mat2_packed}, vkapi::kRead}},
       // Shader params buffers
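Switching the matmul nodes to a fixed `{8, 8, 1}` shape changes how many invocations fall outside the global work group size. The sketch below estimates that cost for the `matmul_naive_texture3d_float` global size quoted in the commit message. It assumes that the global size counts invocations and that the dispatch rounds each dim up to a whole number of work groups; `div_up` and `inactive_invocations` are hypothetical helpers, not ET-VK functions.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Stand-in for utils::uvec3, widened to 64 bits to keep the products safe.
using uvec3 = std::array<std::uint64_t, 3>;

// Ceiling division: number of work groups needed to cover `a` items.
constexpr std::uint64_t div_up(std::uint64_t a, std::uint64_t b) {
  return (a + b - 1) / b;
}

// Launched invocations minus the ones that map to a valid output position.
constexpr std::uint64_t inactive_invocations(
    const uvec3& global_wg,
    const uvec3& local_wg) {
  std::uint64_t launched = 1;
  std::uint64_t active = 1;
  for (std::size_t i = 0; i < 3; ++i) {
    launched *= div_up(global_wg[i], local_wg[i]) * local_wg[i];
    active *= global_wg[i];
  }
  return launched - active;
}

// Global size {29, 115, 32}: old {4, 2, 8} shape vs. new {8, 8, 1} shape.
static_assert(inactive_invocations({29, 115, 32}, {4, 2, 8}) == 12064, "");
static_assert(inactive_invocations({29, 115, 32}, {8, 8, 1}) == 16160, "");
```

Under these assumptions the square shape leaves somewhat more invocations idle for this global size, which is the trade-off the "some inactive invocations are okay" comment in `pick_hw_square_wg_size` accepts in exchange for the reduced volume of unique data read per work group.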

backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ utils::uvec3 linear_qga4w_local_wg_size(
   if (use_coop_algorithm) {
     return {64, 1, 1};
   } else {
-    return graph->create_local_wg_size(global_workgroup_size);
+    return pick_hw_square_wg_size(
+        graph, shader, global_workgroup_size, args, resize_args);
   }
 }
 
