Replies: 1 comment
-
This largely depends on your hardware and the quantization that you use.
The ideal option is to perform sampling on the GPU - some discussion here: #5214
-
I've been looking into why the lookahead decoding implementation added in #4207 doesn't seem to be quite as fast as the paper advertises.
I've discovered that the lookahead decoding implementation requires ~N times the number of sampling evaluations and is therefore penalized much more by slow sampling. In particular, the crux of the improvement seems to be exploiting very cheap sampling relative to the slower sequential evaluation. That is (if I understand the paper correctly), sampling an additional O(W) tokens, at the cost of O(W) extra sampling calls plus one O(W)-sized batch decode, is cheaper than O(W) sequential batch decodes.
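To make that trade-off concrete, here is a minimal cost-model sketch. The window size and the timings (`t_sample`, `t_batch`, `t_single`) are made-up illustrative values, not measurements from llama.cpp:

```cpp
#include <cstdio>

// Rough cost model of the trade-off described above: one lookahead step
// (a batched decode covering ~W positions plus ~W extra sampling calls)
// versus W sequential single-token decodes. All timings are hypothetical
// placeholders in milliseconds, not measurements.
int main() {
    const int    W        = 4;     // lookahead window size
    const double t_sample = 0.5;   // one sampling call, O(n_vocab)
    const double t_batch  = 12.0;  // one batched decode of ~W positions
    const double t_single = 10.0;  // one single-token decode

    const double lookahead  = t_batch + W * t_sample;  // 12 + 4*0.5 = 14
    const double sequential = W * t_single;            //      4*10  = 40

    printf("lookahead  ~ %.1f ms\n", lookahead);
    printf("sequential ~ %.1f ms\n", sequential);

    // The W*t_sample term is pure overhead on the lookahead side, so the
    // advantage shrinks as sampling gets slower (e.g. with a larger vocabulary).
    return 0;
}
```

Under this framing, every millisecond added to a sampling call is multiplied by W on the lookahead side, which is why slow sampling eats into the speedup so quickly.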
The n-gram management cost, on the other hand, seems largely negligible from what I can tell.
This led me to two optimizations (submitted as separate PRs) that luckily also improve performance outside of lookahead decoding.
After these two PRs, I am able to see lookahead decoding perform about 5% faster with W=4, N=3, G=4, temp=0.0, which is still quite a ways from the ~50% advertised in the paper. Additionally, my results are likely also worse than the paper's because sampling is O(n_vocab) and I am testing Llama 3, whose vocabulary (~128k tokens) is roughly 4x larger than that of the Llama 2 models used in the paper (32k tokens), so my sampling cost goes up accordingly while the batch evaluation is not any more expensive. This is a caveat that isn't really mentioned in the paper.
Has anyone else looked into this? I'd like to open up this discussion to talk about possible performance improvements to the sampling code, in particular a few sampling-specific ideas I am thinking about right now:
- `temp=0.0` currently performs two passes over the logits: one to fetch the logits and one to calculate the max. We could instead perform a single pass to roughly double the performance (see the sketch below).
- Can `llama_sample_token_with_rng` be streamlined? We seem to take a vector of logits, wrap them in structs, apply temperature, and then unwrap them again to pass to `std::discrete_distribution`. This seems a bit roundabout.

Additionally, I'd like to start looking into improving batch decode performance too. I am a little surprised at the performance penalty there (I'm seeing ~10 ms per additional batch), but I haven't dug too deep yet and would love to hear other perspectives on potential improvements.
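To make those two ideas concrete, here is a minimal sketch of what they could look like. It operates directly on a raw `float * logits` array; `greedy_argmax` and `sample_with_temp` are hypothetical helpers written for this post, not existing llama.cpp functions:

```cpp
#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

// Illustrative helpers operating directly on the raw logits array.
// These sketch the two ideas above; they are not part of the llama.cpp API.

// temp == 0.0: find the argmax in a single pass over the logits instead of
// first copying them into candidate structs and then scanning for the max.
static int greedy_argmax(const float * logits, int n_vocab) {
    int best_id = 0;
    for (int i = 1; i < n_vocab; ++i) {
        if (logits[i] > logits[best_id]) {
            best_id = i;
        }
    }
    return best_id;
}

// temp > 0.0: apply temperature while exponentiating, then hand the weights
// straight to std::discrete_distribution (which normalizes them internally),
// skipping the intermediate wrap/unwrap of per-token structs.
static int sample_with_temp(const float * logits, int n_vocab, float temp, std::mt19937 & rng) {
    const float max_logit = *std::max_element(logits, logits + n_vocab);

    std::vector<double> weights(n_vocab);
    for (int i = 0; i < n_vocab; ++i) {
        weights[i] = std::exp((logits[i] - max_logit) / temp);
    }

    std::discrete_distribution<int> dist(weights.begin(), weights.end());
    return dist(rng);
}
```

Whether this wins in practice would need benchmarking, but it avoids the second pass in the greedy case and the temporary candidate vector in the stochastic case.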