#param tile_size_k_vec
#param sub_tile_count

// Note that this shader adopts a similar algorithm to the dp4a generation
// shader.
//
// It computes the dot product of v with qk in parallel: at each step the
// tile_size_k_vec threads work along the head_size dimension, while the
// remaining threads in the workgroup process additional rows of
// |present_value| in parallel (so that the values for |qk| in shared memory
// (tile_qk) can be reused). The tile_size_k_vec threads also reload
// |present_value| tile_size/sub_tile_count times to compute partial dot
// products of the other |present_value| rows, completing all tile_size
// |present_value| rows in this workgroup while still reusing the values in
// tile_qk.
//
// The difference from FlashAttentionDecodeQKTProgram is that the dot products
// run over the rows (total_sequence_length) of |present_value| instead of the
// columns (head_size_vec), and each workgroup only computes the dot products
// for its current tile_size instead of iterating over the whole row
// |total_sequence_length|. That is why this is a split shader; the final
// reduce is done in FlashAttentionDecodeReduceProgram.
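As an illustration of the row/thread partitioning described above, here is a hypothetical CPU sketch in NumPy (not the shader itself; the names `tile_size`, `tile_size_k_vec`, `sub_tile_count`, and `head_size` mirror the shader parameters, and the data layout is simplified):

```python
import numpy as np

def split_vx(qk_tile, present_value, tile_size, tile_size_k_vec,
             sub_tile_count, head_size):
    # qk_tile:       (tile_size,) attention weights for this workgroup's rows
    # present_value: (tile_size, head_size) rows of V covered by this split
    out = np.zeros(head_size)
    # sub_tile_count sub-tiles each own tile_size_k_vec threads; together
    # they cover all tile_size rows in tile_size / sub_tile_count passes,
    # reusing qk_tile (the shared-memory tile_qk) on every pass.
    passes = tile_size // sub_tile_count
    for p in range(passes):
        for s in range(sub_tile_count):           # parallel across sub-tiles
            row = p * sub_tile_count + s          # row of present_value
            for t in range(tile_size_k_vec):      # parallel across threads
                # each thread handles a strided slice of head_size
                for h in range(t, head_size, tile_size_k_vec):
                    out[h] += qk_tile[row] * present_value[row, h]
    return out  # partial output; splits are summed in the reduce pass
```

On the GPU the three inner loops are thread-parallel; only the `passes` loop is sequential per thread.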
// TODO: Ideally there should be only two shaders, FlashAttentionDecodeSplitVx
// and FlashAttentionDecodeVxReduce, which would also reduce the intermediate
// memory. FlashAttentionDecodeQKT could be merged into the split shader, with
// the final softmax adjustment done in the reduce shader. However, an issue
// was encountered: when the total sequence length exceeds a certain value,
// the result becomes garbage. Since it can't be resolved in a short time,
// it is left as a TODO to fix in the future.
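For context on the reduce step referenced above: a typical flash-decoding reduce combines per-split partial outputs with a softmax rescaling against the global maximum. The NumPy sketch below is a hypothetical model of that step (the actual FlashAttentionDecodeReduceProgram may differ in detail):

```python
import numpy as np

def reduce_splits(partial_out, split_max, split_sum):
    # partial_out: (num_splits, head_size) unnormalized outputs per split
    # split_max:   (num_splits,) per-split running max of the qk logits
    # split_sum:   (num_splits,) per-split sum of exp(logit - split_max)
    g_max = split_max.max()
    # rescale each split from its local max to the global max
    scale = np.exp(split_max - g_max)
    denom = (split_sum * scale).sum()
    return (partial_out * scale[:, None]).sum(axis=0) / denom
```

Because each split only tracks its own running max, the rescaling factor `exp(split_max - g_max)` is what makes the per-split exponentials comparable before the final normalization.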

var<workgroup> tile_qk: array<present_value_element_t, tile_size>;
var<workgroup> tile_output: array<present_value_value_t, head_size_vec>;
var<workgroup> qkv_values: array<array<present_value_value_t, tile_size_k_vec>, sub_tile_count>;