
Commit 5746ba9

[webgpu] Add back missing code comments for flash decoding (microsoft#25879)
Restore accidentally removed comments when using WGSL template.
1 parent 69ec7b1 commit 5746ba9

File tree

3 files changed: +61 −0 lines changed


onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template

Lines changed: 23 additions & 0 deletions

@@ -6,6 +6,29 @@
 #param tile_size_k_vec
 #param sub_tile_count
 
+// Note that this shader adopts a similar algorithm to the dp4a generation shader.
+//
+// This algorithm computes the dot products of keys with queries in parallel,
+// processing the k (head_size) dimension at each step amongst tile_size_k_vec
+// threads, and utilizing the remaining threads in the workgroup to process
+// additional rows of |present_key| in parallel (so that the values in shared
+// memory (tile_q) for |q| can be reused). For each load of q, the
+// tile_size_k_vec threads also reload |present_key| tile_size/sub_tile_count
+// times to compute partial dot products of other |present_key| rows, in order
+// to complete all tile_size |present_key| rows in this workgroup while also
+// reusing the loaded in-register values of |q|.
+
+// 1. Each workgroup processes one row of |q| and tile_size rows of |present_key|
+//
+// 2. Computation Process:
+//    - Reads a [tile_size][tile_size_k_vec] block of |present_key| data at a
+//      time
+//    - Each thread within the workgroup computes dot products of 4 A*B elements,
+//      since each k represents 4 elements of |present_key|
+//    - Stores intermediate results in shared memory (inner_qk_values)
+//    - Iterates through the columns (head_size_vec), accumulating results in
+//      inner_qk_values
+//    - Performs a final reduction sum over inner_qk_values for the output
 var<workgroup> tile_q: array<q_value_t, tile_size_k_vec>;
 var<workgroup> inner_qk_values: array<array<q_element_t, tile_size_k_vec>, tile_size>;
 var<workgroup> tile_qk: array<q_element_t, tile_size>;
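The restored comments above can be illustrated with a scalar sketch. This is a hypothetical Python helper (`decode_qkt_tile` is not part of the codebase), not the actual WGSL; the sequential loops stand in for threads that run in parallel on the GPU:

```python
# Scalar sketch of the QKT tiling: one workgroup computes the dot products of
# a single |q| row against tile_size rows of |present_key|. Partial sums are
# kept per (row, "thread") like the inner_qk_values shared array, then reduced.
def decode_qkt_tile(q, present_key, tile_size_k_vec=4):
    tile_size = len(present_key)
    head_size = len(q)
    assert head_size % 4 == 0
    head_size_vec = head_size // 4          # each "k" holds 4 elements
    # Shared-memory analogue: per-row, per-thread partial sums.
    inner_qk_values = [[0.0] * tile_size_k_vec for _ in range(tile_size)]
    # Walk the head_size_vec columns, tile_size_k_vec at a time.
    for base in range(0, head_size_vec, tile_size_k_vec):
        for t in range(tile_size_k_vec):    # one "thread" per 4-element chunk
            k = base + t
            if k >= head_size_vec:
                break
            q_vec = q[4 * k:4 * k + 4]      # tile_q: loaded once, reused per row
            for row in range(tile_size):    # remaining threads cover the rows
                key_vec = present_key[row][4 * k:4 * k + 4]
                inner_qk_values[row][t] += sum(a * b for a, b in zip(q_vec, key_vec))
    # Final reduction sum over each row's partials -> tile_qk.
    return [sum(partials) for partials in inner_qk_values]
```

The result equals an ordinary matrix-vector product `present_key @ q`; the tiling only changes how the work is distributed.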

onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template

Lines changed: 27 additions & 0 deletions

@@ -6,6 +6,33 @@
 #param tile_size_k_vec
 #param sub_tile_count
 
+// Note that this shader adopts a similar algorithm to the dp4a generation shader.
+//
+// This algorithm computes the dot products of v with qk in parallel, processing
+// the head_size dimension at each step amongst tile_size_k_vec threads, and
+// utilizing the remaining threads in the workgroup to process additional rows
+// of |present_value| in parallel (so that the values in shared memory (tile_qk)
+// for |qk| can be reused). The tile_size_k_vec threads also reload
+// |present_value| tile_size/sub_tile_count times to compute partial dot
+// products of other |present_value| rows, in order to complete all tile_size
+// |present_value| rows in this workgroup while also reusing the values in
+// tile_qk.
+//
+// The difference from FlashAttentionDecodeQKTProgram is that the dot products
+// go through the rows (total_sequence_length) of |present_value| instead of the
+// columns (head_size_vec), and each workgroup only calculates the current
+// tile_size's dot products instead of iterating over the whole
+// |total_sequence_length| row. That is why this is a split shader; the final
+// reduction is done in FlashAttentionDecodeReduceProgram.
+
+// TODO: Ideally, there should only be two shaders, FlashAttentionDecodeSplitVx
+// and FlashAttentionDecodeVxReduce, which could also reduce the intermediate
+// memory. FlashAttentionDecodeQKT could be merged into the split shader, with
+// the final softmax adjustment done in the reduce shader. However, when the
+// total sequence length exceeds some value, the result becomes garbage. Since
+// this can't be resolved in a short time, it is left as a TODO to fix in the
+// future.
+
 var<workgroup> tile_qk: array<present_value_element_t, tile_size>;
 var<workgroup> tile_output: array<present_value_value_t, head_size_vec>;
 var<workgroup> qkv_values: array<array<present_value_value_t, tile_size_k_vec>, sub_tile_count>;
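The split-then-reduce structure described above can be sketched in scalar Python. These are hypothetical helpers (`decode_split_vx` and `decode_vx_reduce` are illustration names, not the codebase's), and the loops model threads that run concurrently in the real shaders:

```python
# Scalar sketch of the split shader: each "workgroup" weights only its own
# tile_size rows of |present_value| by the matching qk values, producing one
# partial output per tile. A separate reduce step sums the partials, mirroring
# the split shader / reduce shader division of work.
def decode_split_vx(qk, present_value, tile_size):
    total_seq_len = len(present_value)
    head_size = len(present_value[0])
    partials = []
    # One "workgroup" per tile_size rows; no workgroup iterates the
    # whole total_sequence_length.
    for start in range(0, total_seq_len, tile_size):
        out = [0.0] * head_size
        for row in range(start, min(start + tile_size, total_seq_len)):
            weight = qk[row]                 # tile_qk value, reused per column
            for col in range(head_size):     # columns shared amongst threads
                out[col] += weight * present_value[row][col]
        partials.append(out)                 # intermediate split output
    return partials

def decode_vx_reduce(partials):
    # Final reduction across splits, per output column.
    head_size = len(partials[0])
    return [sum(p[col] for p in partials) for col in range(head_size)]
```

Composing the two gives the same result as the full weighted sum over all of total_sequence_length; the split only bounds how much each workgroup touches.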

onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template

Lines changed: 11 additions & 0 deletions

@@ -3,6 +3,17 @@
 
 #param tile_size
 
+// Inputs are splits of the GQA output, split into num_total_seq_length_tiles
+// rows. This shader sums these splits across the row dimension to arrive at
+// the final result; each row is head_size columns wide. The reduction achieves
+// maximum parallelism by first splitting the task into tile_size columns that
+// each workgroup is responsible for. Then, within each workgroup, the summation
+// over num_total_seq_length_tiles for those tile_size columns is further split
+// in two ways: first across the row dimension, giving WORKGROUP_SIZE/TILE_SIZE
+// parallel summations of TILE_SIZE rows each, and then across the column
+// dimension, where each thread is responsible for one column of the TILE_SIZE
+// columns the workgroup is responsible for.
+
 var<workgroup> tile_input: array<array<output_value_t, tile_size>, tile_size>;
 
 $MAIN {
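The two-way split of the reduction can be sketched as follows. This is a hypothetical scalar helper (`vx_reduce_workgroup` is an illustration name), with sequential loops standing in for the parallel threads of one workgroup:

```python
# Scalar sketch of the reduce shader's two-stage sum: one "workgroup" owns
# tile_size output columns starting at col_base and must sum all split rows
# for them.
def vx_reduce_workgroup(splits, col_base, tile_size, workgroup_size):
    rows_per_pass = workgroup_size // tile_size
    num_splits = len(splits)
    # Stage 1: WORKGROUP_SIZE/TILE_SIZE "thread rows" each sum a strided
    # subset of the split rows into the tile_input shared-memory analogue.
    tile_input = [[0.0] * tile_size for _ in range(rows_per_pass)]
    for r in range(rows_per_pass):
        for s in range(r, num_splits, rows_per_pass):
            for c in range(tile_size):
                tile_input[r][c] += splits[s][col_base + c]
    # Stage 2: one thread per column sums the stage-1 partial rows.
    return [sum(tile_input[r][c] for r in range(rows_per_pass))
            for c in range(tile_size)]
```

Every split row lands in exactly one stage-1 stride, so the two stages together equal a plain column-wise sum over all splits.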
