
Commit 18f91e5

[webgpu] Flash attention for generation (microsoft#23808)
This PR adds flash decoding support to speed up generation when the total sequence length is large. Previously, once the total sequence length grew large enough, the softmax and softmax * V shaders became the bottleneck, since they use only a limited number of GPU cores. This change splits the present key/value along the total sequence length, computes partial attention results per split, and then reduces the partial results into the final output.

Measured on phi4 with a static KV cache at 1K tokens:
- NVIDIA RTX 2000 Ada: TPS improves from 34.4 to 41.4.
- Meteor Lake: TPS improves from 16 to 19.

Side effect of this PR: it adds two extra buffers, which increases memory usage:
1) metadata (the max and exp_sum of each split), and
2) the split QKV results with shape [B, N, split_k, H].

TODO: Ideally there should be only two shaders, which would also reduce the intermediate memory: computeQKT could be merged into the split shader, with the final softmax adjustment done in the reduce shader. However, I hit an issue where the result becomes garbage once the total sequence length exceeds a certain value. Since I can't resolve it quickly, it is left as a TODO to fix in the future.
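For illustration, below is a minimal NumPy sketch of the split-then-reduce softmax scheme described above, simplified to a single query vector. The function names, shapes, and split layout are assumptions for clarity, not the actual WGSL shader interface; in the real implementation the per-split max and exp_sum live in the metadata buffer and the partial outputs in the [B, N, split_k, H] buffer.

```python
# Minimal NumPy sketch of split-K attention (flash decoding) for one query.
# All names and shapes here are illustrative assumptions, not the actual shaders.
import numpy as np

def attend_one_split(q, k_split, v_split):
    """Attend a single query to one split of the KV cache.

    Returns the unnormalized partial output plus the per-split metadata
    (local max and exp-sum) that the reduce step needs.
    q: [H]; k_split, v_split: [S_split, H].
    """
    scores = k_split @ q / np.sqrt(q.shape[-1])  # [S_split]
    m = scores.max()                             # split-local max
    p = np.exp(scores - m)                       # unnormalized softmax weights
    return p @ v_split, m, p.sum()               # partial output is [H]

def reduce_splits(outs, maxes, exp_sums):
    """Combine the per-split partial results into the exact softmax output.

    Each split is rescaled by exp(m_i - m_global), so dividing by the
    rescaled global exp-sum reproduces an unsplit softmax exactly.
    outs: [split_k, H]; maxes, exp_sums: [split_k].
    """
    scale = np.exp(maxes - maxes.max())          # [split_k]
    total = (exp_sums * scale).sum()
    return (outs * scale[:, None]).sum(axis=0) / total

# Verify against an unsplit reference over a long KV cache.
rng = np.random.default_rng(0)
H, total_seq, split_k = 64, 1024, 8
q = rng.standard_normal(H)
k = rng.standard_normal((total_seq, H))
v = rng.standard_normal((total_seq, H))

parts = [attend_one_split(q, ks, vs)
         for ks, vs in zip(np.split(k, split_k), np.split(v, split_k))]
outs = np.stack([p[0] for p in parts])
maxes = np.array([p[1] for p in parts])
exp_sums = np.array([p[2] for p in parts])

scores = k @ q / np.sqrt(H)
w = np.exp(scores - scores.max())
ref = (w / w.sum()) @ v
assert np.allclose(reduce_splits(outs, maxes, exp_sums), ref)
```

Because each split carries its own max and exp-sum, the splits can be processed by independent workgroups in parallel, and the final reduce remains exact rather than approximate.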
1 parent 554fb4a · commit 18f91e5


2 files changed: +453 −46 lines

