Commit f95a3f7
authored
[ET-VK] Better work group sizes for matmul (#13378)
## Context
Currently `default_pick_local_wg_size()` (which internally calls `ComputeGraph::create_local_wg_size`) is used to select the local work group size for matrix multiplication ops. However, these functions currently bias the size of the local work group towards the largest dim of the global work group producing local wg sizes like
```
shader globalwg size localwg size
=========== ===================== ==================== =============
linear_qga4w_tiled_texture3d_texture3d_texture2d_float {256, 29, 1} {32, 2, 1} 1487
matmul_naive_texture3d_float {29, 115, 32} {4, 2, 8} 712
```
for matrix multiplication shaders. This behaviour was introduced in D64418632 / #6409.
However, through experimental testing a "square" work group size of `{8, 8, 1}` works a lot better for matrix multiplication shaders. The theoretical analysis for this behaviour is that the local work group size determines the memory locations that need to be loaded to compute the overall work group. For a work group with size `{W, H, 1}` the data required to compute the output would be `W * OUTPUT_TILE_W` columns of the weight tensor and `H * OUTPUT_TILE_H` rows of the input tensor. Note that all work group items in the same W index will be requesting the same columns from the weight tensor, and all work group items in the same H index will be requesting the same rows from the input tensor.
If `H==W`, then that "balances" the amount of data needed to loaded from each input tensor and may result in better data sharing behaviour among all work group items. Assuming `OUTPUT_TILE_W == OUTPUT_TILE_H == 1`, a local work group of size `{64, 1, 1}` would require 1 unique row from the input tensor an 64 unique columns to be loaded from the weight tensor, resulting in `(1 + 64) * K = 65K` elements to be loaded in total, where K is the size of the shared reduction dim. Conversely, a local work group of size `{8, 8, 1}` would require 8 unique rows / 8 unique columns resulting in only `(8 + 8) * K = 16K` unique elements to be loaded.
This highlights the need to use dedicated logic to compute work group sizes for matrix multiplication shaders.
## Changes
* Introduce `pick_hw_square_wg_size`
* Use the new local work group size determination function for Quantized Linear, Matmul, and Linear
Differential Revision: [D79813236](https://our.internmc.facebook.com/intern/diff/D79813236/)1 parent 75b77a6 commit f95a3f7
File tree
5 files changed
+48
-6
lines changed- backends/vulkan/runtime/graph/ops/impl
5 files changed
+48
-6
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
33 | 33 | | |
34 | 34 | | |
35 | 35 | | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
36 | 59 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
36 | 36 | | |
37 | 37 | | |
38 | 38 | | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
39 | 57 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
178 | 178 | | |
179 | 179 | | |
180 | 180 | | |
181 | | - | |
| 181 | + | |
182 | 182 | | |
183 | 183 | | |
184 | 184 | | |
| |||
245 | 245 | | |
246 | 246 | | |
247 | 247 | | |
248 | | - | |
| 248 | + | |
249 | 249 | | |
250 | 250 | | |
251 | 251 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
102 | 102 | | |
103 | 103 | | |
104 | 104 | | |
105 | | - | |
| 105 | + | |
106 | 106 | | |
107 | 107 | | |
108 | 108 | | |
| |||
158 | 158 | | |
159 | 159 | | |
160 | 160 | | |
161 | | - | |
| 161 | + | |
162 | 162 | | |
163 | 163 | | |
164 | 164 | | |
| |||
273 | 273 | | |
274 | 274 | | |
275 | 275 | | |
276 | | - | |
| 276 | + | |
277 | 277 | | |
278 | 278 | | |
279 | 279 | | |
| |||
Lines changed: 2 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
158 | 158 | | |
159 | 159 | | |
160 | 160 | | |
161 | | - | |
| 161 | + | |
| 162 | + | |
162 | 163 | | |
163 | 164 | | |
164 | 165 | | |
| |||
0 commit comments