feat(candle-nn) ConcatKvCache for 2-5x GPU speedup on autoregressive generation #3143
Conversation
ce5b63c to 6984bdd
Thank you for this!

I know you updated with main just 30 minutes ago, but I just merged a couple of PRs to main and I want to see how they interact with this, so I'm doing another one :)
ivarflakstad left a comment
PR is in a pretty good state when almost all I have to comment on is that the documentation is too verbose 👌
if offset == 0 {
    self.kv_cache.reset();
}
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
Ooh interesting. This most recent change actually removes all the speedup we got on metal.
Non-quantized version lost some t/s as well, but not as bad. I assume the diff is due to how matmul and qmatmul perform on non-contiguous tensors.
I'll benchmark with/without .contiguous() on CPU and CUDA (knowing that contiguous improves Metal performance). If different devices need different behavior, I'll add device-specific dispatch inside the KvCache implementation to handle .contiguous() automatically based on device type.
If Cuda is unaffected I'd prefer, if possible, to keep this simple and instead improve non-contiguous matmul/qmatmul.
We're updating metal matmul very soon anyway.
Most recent tests actually showed a CUDA benefit from making the tensors contiguous before saving, on the quantized version. No difference on CPU. I tested three code variations:
No Contiguous
let (k, v) = self.kv_cache.append(&k, &v)?;
let k = repeat_kv(k, self.num_kv_groups)?;
let v = repeat_kv(v, self.num_kv_groups)?;
Append Contiguous
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
let k = repeat_kv(k, self.num_kv_groups)?;
let v = repeat_kv(v, self.num_kv_groups)?;
+Repeat_kv Contig
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
On CPU and GPU, for quantized (8_0) and full-precision Qwen3-0.6B:
| Example | Features | Model | Sample Length | Configuration | Tokens Generated | Speed (token/s) |
|---|---|---|---|---|---|---|
| qwen | - | 3-0.6b | 100 | no contiguous | 100 | 7.48 |
| qwen | - | 3-0.6b | 100 | append contiguous | 100 | 7.51 |
| qwen | - | 3-0.6b | 100 | +repeat_kv contig | 100 | 7.45 |
| Example | Features | Model | Sample Length | Configuration | Tokens Generated | Speed (token/s) |
|---|---|---|---|---|---|---|
| qwen | cuda | 3-0.6b | 1000 | no contiguous | 1000 | 114.00 |
| qwen | cuda | 3-0.6b | 1000 | append contiguous | 1000 | 109.38 |
| qwen | cuda | 3-0.6b | 1000 | +repeat_kv contig | 1000 | 113.76 |
| Example | Features | Model | Sample Length | Configuration | Tokens Generated | Speed (token/s) |
|---|---|---|---|---|---|---|
| quantized-qwen3 | - | 0.6b | 100 | no contiguous | 94 | 34.54 |
| quantized-qwen3 | - | 0.6b | 100 | append contiguous | 94 | 34.54 |
| quantized-qwen3 | - | 0.6b | 100 | +repeat_kv contig | 94 | 35.31 |
| Example | Features | Model | Sample Length | Configuration | Tokens Generated | Speed (token/s) |
|---|---|---|---|---|---|---|
| quantized-qwen3 | cuda | 0.6b | 1000 | no contiguous | 962 | 62.96 |
| quantized-qwen3 | cuda | 0.6b | 1000 | append contiguous | 962 | 100.43 |
| quantized-qwen3 | cuda | 0.6b | 1000 | +repeat_kv contig | 962 | 100.79 |
This suggests to me that there is no noticeable difference except in concatenating quantized k,v tensors.
Do you think I should look into that as a separate PR?
nice!

> Do you think I should look into that as a separate PR?

We'll certainly note it for later!
But it looks like your +repeat_kv contig is consistently good across the board, so let's go with that approach for now? :)
Ok. I'll do that.
remove verbose kv-cache description Co-authored-by: ivarflakstad <[email protected]>
consolidate tests Co-authored-by: ivarflakstad <[email protected]>
Large improvements for kv_cache append of quantized tensors when in contiguous layout
Since always using contiguous
after contiguous
contiguous called inside append
improves some devices but doesn't hurt others
ivarflakstad left a comment
Lgtm! Thank you 🙌
Add ConcatKvCache for 2-5x GPU speedup on autoregressive generation
Summary
Adds a new `ConcatKvCache` implementation that uses `Tensor::cat` instead of `slice_set` for KV-cache updates, providing 2-5x GPU performance improvements for autoregressive generation.

Motivation

The standard `KvCache` uses pre-allocated buffers with `slice_set` updates, which has suboptimal performance on GPU. In contrast, `Tensor::cat` uses optimized concatenation kernels. This PR adds `ConcatKvCache` as a new option alongside the existing `KvCache`, allowing developers to choose the best implementation for their use case.

Changes
1. Added `ConcatKvCache` to `candle-nn/src/kv_cache.rs`

A new KV-cache implementation that uses `Tensor::cat` for append operations instead of `slice_set` (as in the existing `KvCache`).

Key features: `new(dim)`, `append(k, v)`, `reset()`; it relies on the optimized `cat` kernels (Optimize the cat operation on contiguous tensors #1855).
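For reference, a minimal sketch of the concat-based approach (this is illustrative only, not the actual candle-nn implementation; the crate alias and struct layout are assumptions):

```rust
use candle::{Result, Tensor};

/// Sketch of a concat-based KV cache: keys/values grow by concatenation
/// along `dim` instead of being written into a pre-allocated buffer.
pub struct ConcatKvCacheSketch {
    k: Option<Tensor>,
    v: Option<Tensor>,
    dim: usize,
}

impl ConcatKvCacheSketch {
    pub fn new(dim: usize) -> Self {
        Self { k: None, v: None, dim }
    }

    /// Appends `k`/`v` along `dim` and returns the full cached tensors.
    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.k = Some(match self.k.take() {
            None => k.clone(),
            Some(prev) => Tensor::cat(&[&prev, k], self.dim)?,
        });
        self.v = Some(match self.v.take() {
            None => v.clone(),
            Some(prev) => Tensor::cat(&[&prev, v], self.dim)?,
        });
        Ok((self.k.clone().unwrap(), self.v.clone().unwrap()))
    }

    pub fn reset(&mut self) {
        self.k = None;
        self.v = None;
    }
}
```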
2. Updated Qwen3 to use `ConcatKvCache`

Modified `candle-transformers/src/models/qwen3.rs` (3 lines changed).

3. Qwen3 MoE automatically benefits

No changes needed to `qwen3_moe.rs` - it imports `Qwen3Attention` and automatically inherits the performance improvement.

4. Updated Quantized-Qwen3 to use `ConcatKvCache`

Same updates as implemented in Qwen3.
Easy Migration Path
`ConcatKvCache` is designed as a near drop-in replacement with identical runtime API:

| | `KvCache` | `ConcatKvCache` |
|---|---|---|
| Constructor | `new(dim, max_len)` | `new(dim)` |
| Append | `append(k, v)` | `append(k, v)` |
| Reset | `reset()` | `reset()` |
| Append return type | `(Tensor, Tensor)` | `(Tensor, Tensor)` |

Migration effort per model: change 3 lines, get a 2-5x speedup.
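As a hedged illustration of that three-line change (the `candle_nn::kv_cache::ConcatKvCache` path is assumed from this PR; the exact qwen3.rs diff may differ):

```rust
use candle_nn::kv_cache::{ConcatKvCache, KvCache};

// The per-model migration is essentially:
//   1. change the import,
//   2. change the field type,
//   3. change the constructor call.
fn build_caches(max_position_embeddings: usize) -> (KvCache, ConcatKvCache) {
    // Existing cache: pre-allocated up to max_position_embeddings along dim 2 (the seq dim).
    let old = KvCache::new(2, max_position_embeddings);
    // New cache: no capacity argument; it grows by concatenation along dim 2.
    let new = ConcatKvCache::new(2);
    (old, new)
}
```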
Performance Results
Benchmarked on Qwen3-0.6B:
Hardware:
GPU Performance - Significant Improvement
Quantized Model Performance
Also tested on Quantized Qwen3-0.6B (8-bit) to verify the optimization works across model types:
Key insight: Speedup increases with sequence length, making this especially valuable for long-context applications.
Speedup Growth Pattern
The performance advantage grows with sequence length:
This is because `slice_set`'s overhead compounds as the cache grows (larger strides), while `cat` maintains efficient sequential access patterns.

CPU Performance - Neutral
CPU performance is essentially unchanged, confirming this optimization specifically targets GPU bottlenecks.
Design Rationale
Why add `ConcatKvCache` instead of modifying `KvCache`?

Different use cases have different optimal implementations (`KvCache`, `ConcatKvCache`, `RotatingKvCache`, `ScatteredKvCache`). By keeping both implementations, developers can choose the right tool for their specific hardware and use case.
Why is `cat` faster on GPU?

Both `KvCache` and `ConcatKvCache` use the optimized `copy2d` kernel from PR #1855, but they feed it different parameters: `ConcatKvCache` goes through `cat_contiguous`, while `KvCache` goes through `slice_set`. See `candle-core/src/tensor_cat.rs` for the optimized `cat_contiguous` implementation.
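As a rough illustration of the two code paths at the `Tensor` API level (a sketch with arbitrary shapes; it simplifies what `KvCache` actually does internally):

```rust
use candle::{DType, Device, Result, Tensor};

fn append_styles() -> Result<()> {
    let dev = Device::Cpu;
    // One decode step: a single new token's keys, shape (batch, heads, 1, head_dim).
    let new_k = Tensor::zeros((1, 4, 1, 64), DType::F32, &dev)?;
    let offset = 10; // tokens already cached

    // KvCache style: write the new slice into a pre-allocated buffer at `offset`
    // along the seq dimension (dim 2), then narrow to the filled prefix.
    let buffer = Tensor::zeros((1, 4, 512, 64), DType::F32, &dev)?;
    buffer.slice_set(&new_k, 2, offset)?;
    let k_in_use = buffer.narrow(2, 0, offset + 1)?;

    // ConcatKvCache style: concatenate the existing cache with the new slice,
    // producing a fresh contiguous tensor via the optimized cat kernels.
    let cached = Tensor::zeros((1, 4, offset, 64), DType::F32, &dev)?;
    let k_cat = Tensor::cat(&[&cached, &new_k], 2)?;

    assert_eq!(k_in_use.dims(), k_cat.dims());
    Ok(())
}
```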
When to Use Each Cache

Added documentation to `kv_cache.rs` describing when to use each cache type (`ConcatKvCache`, `KvCache`, `RotatingKvCache`, `ScatteredKvCache`).

Testing
Unit Tests
Added 4 comprehensive tests for `ConcatKvCache`.
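A sketch of the kind of check such tests cover (test name and import path are assumptions, not the actual tests in this PR):

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::ConcatKvCache; // assumed path: the cache added by this PR

#[test]
fn concat_kv_cache_append_and_reset() -> Result<()> {
    let dev = Device::Cpu;
    let mut cache = ConcatKvCache::new(2); // seq dimension is dim 2

    // Prefill with 3 tokens, then append 1 more.
    let k0 = Tensor::zeros((1, 2, 3, 8), DType::F32, &dev)?;
    let v0 = Tensor::zeros((1, 2, 3, 8), DType::F32, &dev)?;
    let (k, v) = cache.append(&k0, &v0)?;
    assert_eq!(k.dims()[2], 3);
    assert_eq!(v.dims()[2], 3);

    let k1 = Tensor::zeros((1, 2, 1, 8), DType::F32, &dev)?;
    let v1 = Tensor::zeros((1, 2, 1, 8), DType::F32, &dev)?;
    let (k, v) = cache.append(&k1, &v1)?;
    assert_eq!(k.dims()[2], 4);
    assert_eq!(v.dims()[2], 4);

    // After reset, the next append starts a fresh sequence.
    cache.reset();
    let (k, _v) = cache.append(&k1, &v1)?;
    assert_eq!(k.dims()[2], 1);
    Ok(())
}
```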
Integration Tests

Benchmark Command
Related PRs
PR #1855 (Optimize the cat operation on contiguous tensors): `copy2d` kernel (both caches benefit from this)

Breaking Changes
None. This PR adds `ConcatKvCache` as a new option alongside the existing `KvCache`.
Checklist