This PR adds flash decoding support to improve generation speed when the
total sequence length is large. Previously, once the total sequence
length grew large enough, the softmax and softmax * v shaders became the
bottleneck because they only used a limited number of GPU cores. With
this change, flash decoding splits the present key/value along the total
sequence length, computes partial results per split, then runs a reduce
step to combine them into the final result.
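The split-then-reduce scheme can be sketched in NumPy for a single head and a single query token (the function names, shapes, and the `num_splits` parameter are illustrative only, not the actual shader code):

```python
import numpy as np

def attention_ref(q, K, V):
    # Reference: softmax(q . K^T / sqrt(H)) . V for one head, one query token.
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V

def flash_decode(q, K, V, num_splits):
    H = q.shape[0]
    chunks = np.array_split(np.arange(K.shape[0]), num_splits)
    maxes, exp_sums, partials = [], [], []
    # Split phase: each split computes its own max, exp_sum, and an
    # unnormalized partial output (the per-split "metadata" + qkv result).
    for idx in chunks:
        s = K[idx] @ q / np.sqrt(H)
        m = s.max()
        p = np.exp(s - m)
        maxes.append(m)
        exp_sums.append(p.sum())
        partials.append(p @ V[idx])
    # Reduce phase: rescale every partial to a common global max,
    # then normalize by the combined exp_sum.
    m_global = max(maxes)
    scales = [np.exp(m - m_global) for m in maxes]
    denom = sum(e * c for e, c in zip(exp_sums, scales))
    return sum(o * c for o, c in zip(partials, scales)) / denom
```

The rescaling works because exp(s - m_global) = exp(s - m_i) * exp(m_i - m_global), so each split's partial output and exp_sum only need one multiplicative correction in the reduce step.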
On NV RTX 2000 Ada, TPS improves from 34.4 to 41.4 at 1K tokens for
phi4 with static kv cache.
On Meteor Lake, TPS improves from 16 to 19 at 1K tokens for phi4 with
static kv cache.
Side effect of this PR:
It adds two extra buffers: 1) metadata (the max and exp_sum of each
split), and 2) the split qkv results with shape [B, N, split_k, H],
which increases memory usage.
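For a rough sense of that overhead, a back-of-the-envelope calculation (all dimensions below are assumed, illustrative values, not the actual phi4 configuration):

```python
# Hypothetical dimensions for illustration only.
B, N, H = 1, 32, 96     # batch, heads, head size (assumed)
split_k = 16            # number of sequence-length splits (assumed)
bytes_f32 = 4

# Metadata buffer: one (max, exp_sum) pair per split per head.
metadata_bytes = B * N * split_k * 2 * bytes_f32
# Partial qkv buffer with shape [B, N, split_k, H].
partial_bytes = B * N * split_k * H * bytes_f32
```

With these assumed numbers the partial-qkv buffer dominates (a few hundred KB), while the metadata buffer is small; both grow linearly with split_k.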
TODO:
Ideally there should be only two shaders, which would also reduce the
intermediate memory: computeQKT can be merged into the split shader,
with the final softmax adjustment done in the reduce shader. However, I
hit an issue where the result becomes garbage once the total sequence
length exceeds some value. Since I can't resolve it in a short time, I'm
leaving this as a TODO to fix in the future.