Commit 60252cc
feat(candle-nn): ConcatKvCache for 2-5x GPU speedup on autoregressive generation (#3143)

* add concat cache; use in qwen3
* update tradeoff desc; resolve unused var warning in ConcatKV test
* update kv-cache concat method description
* quant-qwen leverage ConcatKV; add 8_0 to example main
* format 8_0 load
* remove trailing ,
* trailing line
* removed unnecessary contiguous calls
* Update candle-nn/src/kv_cache.rs: remove verbose kv-cache description (Co-authored-by: ivarflakstad <[email protected]>)
* Update candle-nn/src/kv_cache.rs: remove verbose kv-cache description (Co-authored-by: ivarflakstad <[email protected]>)
* Update candle-nn/src/kv_cache.rs: remove verbose kv-cache description (Co-authored-by: ivarflakstad <[email protected]>)
* Update candle-nn/src/kv_cache.rs: consolidate tests (Co-authored-by: ivarflakstad <[email protected]>)
* Update candle-transformers/src/models/quantized_qwen3.rs: large improvements for kv_cache; append quantized tensors when in contiguous layout
* Update candle-nn/src/kv_cache.rs: since always using contiguous
* Update candle-nn/src/kv_cache.rs: after contiguous
* Update candle-nn/src/kv_cache.rs: after contiguous
* Update candle-transformers/src/models/quantized_qwen3.rs: contiguous called inside append
* Update candle-transformers/src/models/quantized_qwen3.rs: improves some devices but doesn't hurt others
* make k and v contiguous post repeat in qwen3

Co-authored-by: ivarflakstad <[email protected]>
1 parent: db08cc0

File tree: 4 files changed, +284 -22 lines changed

- candle-examples/examples/quantized-qwen3/main.rs (+6 -0)
- candle-nn/src/kv_cache.rs (+266 -0)
- candle-transformers/src/models/quantized_qwen3.rs (+4 -14)
- candle-transformers/src/models/qwen3.rs (+8 -8)
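The tradeoff behind the speedup, sketched with plain tensor ops standing in for the two caches' internals (shapes and the capacity number are illustrative, not taken from this diff): `KvCache` pre-allocates a fixed buffer and writes each step into it with `slice_set`, while `ConcatKvCache` grows by `Tensor::cat`, which maps to a single fused allocate-and-copy kernel on GPU.

use candle::{DType, Device, Result, Tensor};

fn append_strategies(device: &Device) -> Result<()> {
    // One decode step's worth of new keys: [batch, heads, 1, head_dim].
    let new_k = Tensor::zeros((1, 8, 1, 64), DType::F32, device)?;

    // KvCache-style append: in-place write into a fixed-capacity buffer.
    let buf = Tensor::zeros((1, 8, 512, 64), DType::F32, device)?;
    let offset = 10; // tokens already cached
    buf.slice_set(&new_k, 2, offset)?;

    // ConcatKvCache-style append: grow the cache by one token with a single cat.
    let cache = Tensor::zeros((1, 8, 10, 64), DType::F32, device)?;
    let grown = Tensor::cat(&[&cache, &new_k], 2)?;
    assert_eq!(grown.dim(2)?, 11);
    Ok(())
}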

candle-examples/examples/quantized-qwen3/main.rs
Lines changed: 6 additions & 0 deletions

@@ -21,6 +21,8 @@ const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial o
 enum Which {
     #[value(name = "0.6b")]
     W3_0_6b,
+    #[value(name = "0.6b8_0")]
+    W3_0_6b8_0,
     #[value(name = "1.7b")]
     W3_1_7b,
     #[value(name = "4b")]
@@ -103,6 +105,7 @@ impl Args {
         let api = hf_hub::api::sync::Api::new()?;
         let repo = match self.which {
             Which::W3_0_6b => "Qwen/Qwen3-0.6B",
+            Which::W3_0_6b8_0 => "Qwen/Qwen3-0.6B",
             Which::W3_1_7b => "Qwen/Qwen3-1.7B",
             Which::W3_4b => "Qwen/Qwen3-4B",
             Which::W3_8b => "Qwen/Qwen3-8B",
@@ -122,6 +125,9 @@ impl Args {
             None => {
                 let (repo, filename, revision) = match self.which {
                     Which::W3_0_6b => ("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q4_K_M.gguf", "main"),
+                    Which::W3_0_6b8_0 => {
+                        ("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q8_0.gguf", "main")
+                    }
                     Which::W3_1_7b => ("unsloth/Qwen3-1.7B-GGUF", "Qwen3-1.7B-Q4_K_M.gguf", "main"),
                     Which::W3_4b => ("unsloth/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf", "main"),
                     Which::W3_8b => ("unsloth/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf", "main"),
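With the new variant in place, the Q8_0 quantization of Qwen3-0.6B becomes selectable from the example's CLI. Assuming clap maps the `#[value(name = "0.6b8_0")]` attribute to a `--which` flag, as the existing variants suggest, the invocation would look like `cargo run --example quantized-qwen3 --release -- --which 0.6b8_0`.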

candle-nn/src/kv_cache.rs
Lines changed: 266 additions & 0 deletions

@@ -631,6 +631,174 @@ impl ScatteredCacheBuilder {
     }
 }
 
+/// KV-Cache using concatenation for append operations
+///
+/// This implementation uses `Tensor::cat` instead of `slice_set` for updates,
+/// providing significant GPU performance improvements for autoregressive generation.
+///
+/// # When to Use
+///
+/// **Recommended for:**
+/// - GPU inference (CUDA, Metal)
+/// - Autoregressive generation (token-by-token decoding)
+///
+/// **Use `KvCache` instead for:**
+/// - CPU-only inference
+/// - When you need fixed memory allocation upfront
+///
+/// # Example
+///
+/// ```ignore
+/// use candle_nn::kv_cache::ConcatKvCache;
+///
+/// let mut cache = ConcatKvCache::new(2); // dim=2 for sequence dimension
+///
+/// // First token (prefill)
+/// let k1 = Tensor::randn(0f32, 1., (1, 8, 10, 64), &device)?;
+/// let v1 = Tensor::randn(0f32, 1., (1, 8, 10, 64), &device)?;
+/// let (k, v) = cache.append(&k1, &v1)?;
+///
+/// // Subsequent tokens (decode)
+/// let k_new = Tensor::randn(0f32, 1., (1, 8, 1, 64), &device)?;
+/// let v_new = Tensor::randn(0f32, 1., (1, 8, 1, 64), &device)?;
+/// let (k, v) = cache.append(&k_new, &v_new)?;
+/// ```
+#[derive(Debug, Clone)]
+pub struct ConcatKvCache {
+    k: Option<Tensor>,
+    v: Option<Tensor>,
+    dim: usize,
+}
+
+impl ConcatKvCache {
+    /// Create a new empty concatenation-based KV-cache
+    ///
+    /// # Arguments
+    /// * `dim` - The dimension along which to concatenate
+    ///   - For attention with shape `[batch, heads, seq, head_dim]`, use `dim=2`
+    ///   - For attention with shape `[batch, seq, heads, head_dim]`, use `dim=1`
+    ///
+    /// # Example
+    /// ```ignore
+    /// // For standard transformer attention: [B, H, S, D]
+    /// let cache = ConcatKvCache::new(2);
+    /// ```
+    pub fn new(dim: usize) -> Self {
+        Self {
+            k: None,
+            v: None,
+            dim,
+        }
+    }
+
+    /// Get current sequence length in the cache
+    ///
+    /// Returns 0 if the cache is empty.
+    pub fn current_seq_len(&self) -> usize {
+        self.k
+            .as_ref()
+            .and_then(|k| k.dims().get(self.dim).copied())
+            .unwrap_or(0)
+    }
+
+    /// Check if the cache is empty
+    pub fn is_empty(&self) -> bool {
+        self.k.is_none()
+    }
+
+    /// Get the concatenation dimension
+    pub fn dim(&self) -> usize {
+        self.dim
+    }
+
+    /// Append key and value tensors to the cache
+    ///
+    /// This is the core operation that uses optimized concatenation kernels.
+    ///
+    /// # Arguments
+    /// * `k` - Key tensor to append (shape: `[..., seq_len, ...]`)
+    /// * `v` - Value tensor to append (shape: `[..., seq_len, ...]`)
+    ///
+    /// # Returns
+    /// Tuple of `(full_k, full_v)` containing all cached keys and values,
+    /// including the newly appended data.
+    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
+        // Ensure inputs are contiguous for optimal concatenation performance
+        let k = k.contiguous()?;
+        let v = v.contiguous()?;
+        // Update K cache using concatenation
+        self.k = Some(match &self.k {
+            None => k.clone(),
+            Some(k_cache) => {
+                // Concatenate along the sequence dimension.
+                // The GPU kernel for cat is highly optimized:
+                // - Fused allocation + copy
+                // - Coalesced memory access
+                // - Single kernel launch
+                Tensor::cat(&[k_cache, &k], self.dim)?
+            }
+        });
+
+        // Update V cache using concatenation
+        self.v = Some(match &self.v {
+            None => v.clone(),
+            Some(v_cache) => Tensor::cat(&[v_cache, &v], self.dim)?,
+        });
+
+        Ok((
+            self.k.as_ref().unwrap().clone(),
+            self.v.as_ref().unwrap().clone(),
+        ))
+    }
+
+    /// Reset the cache (clear all stored keys and values)
+    ///
+    /// After calling this, `is_empty()` will return `true` and
+    /// `current_seq_len()` will return 0.
+    pub fn reset(&mut self) {
+        self.k = None;
+        self.v = None;
+    }
+
+    /// Get reference to current K cache data
+    ///
+    /// Returns `None` if the cache is empty.
+    pub fn k(&self) -> Option<&Tensor> {
+        self.k.as_ref()
+    }
+
+    /// Get reference to current V cache data
+    ///
+    /// Returns `None` if the cache is empty.
+    pub fn v(&self) -> Option<&Tensor> {
+        self.v.as_ref()
+    }
+
+    /// Get mutable reference to K cache data
+    ///
+    /// Returns `None` if the cache is empty.
+    pub fn k_mut(&mut self) -> Option<&mut Tensor> {
+        self.k.as_mut()
+    }
+
+    /// Get mutable reference to V cache data
+    ///
+    /// Returns `None` if the cache is empty.
+    pub fn v_mut(&mut self) -> Option<&mut Tensor> {
+        self.v.as_mut()
+    }
+
+    /// Get owned K and V tensors, consuming the cache
+    ///
+    /// Returns `None` if the cache is empty.
+    pub fn into_inner(self) -> Option<(Tensor, Tensor)> {
+        match (self.k, self.v) {
+            (Some(k), Some(v)) => Some((k, v)),
+            _ => None,
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -717,4 +885,102 @@ mod tests {
 
         Ok(())
     }
+
+    #[test]
+    fn test_concat_cache_basic() -> Result<()> {
+        let device = Device::Cpu;
+        let mut cache = ConcatKvCache::new(2);
+
+        assert!(cache.is_empty());
+        assert_eq!(cache.current_seq_len(), 0);
+
+        // First append
+        let k1 = Tensor::zeros((1, 8, 3, 64), DType::F32, &device)?;
+        let v1 = Tensor::zeros((1, 8, 3, 64), DType::F32, &device)?;
+        let (k, v) = cache.append(&k1, &v1)?;
+
+        assert_eq!(k.dims(), &[1, 8, 3, 64]);
+        assert_eq!(v.dims(), &[1, 8, 3, 64]);
+        assert_eq!(cache.current_seq_len(), 3);
+        assert!(!cache.is_empty());
+
+        // Second append
+        let k2 = Tensor::zeros((1, 8, 2, 64), DType::F32, &device)?;
+        let v2 = Tensor::zeros((1, 8, 2, 64), DType::F32, &device)?;
+        let (k, v) = cache.append(&k2, &v2)?;
+
+        assert_eq!(k.dims(), &[1, 8, 5, 64]); // 3 + 2
+        assert_eq!(v.dims(), &[1, 8, 5, 64]);
+        assert_eq!(cache.current_seq_len(), 5);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_concat_cache_reset() -> Result<()> {
+        let device = Device::Cpu;
+        let mut cache = ConcatKvCache::new(2);
+
+        let k = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
+        let v = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
+        cache.append(&k, &v)?;
+
+        assert_eq!(cache.current_seq_len(), 10);
+
+        cache.reset();
+
+        assert!(cache.is_empty());
+        assert_eq!(cache.current_seq_len(), 0);
+        assert!(cache.k().is_none());
+        assert!(cache.v().is_none());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_concat_cache_multiple_appends() -> Result<()> {
+        let device = Device::Cpu;
+        let mut cache = ConcatKvCache::new(2);
+
+        // Simulate autoregressive generation
+        let k_prefill = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
+        let v_prefill = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
+        cache.append(&k_prefill, &v_prefill)?;
+
+        assert_eq!(cache.current_seq_len(), 10);
+
+        // Decode phase: append one token at a time
+        for i in 1..=5 {
+            let k_token = Tensor::zeros((1, 8, 1, 64), DType::F32, &device)?;
+            let v_token = Tensor::zeros((1, 8, 1, 64), DType::F32, &device)?;
+            let (k, v) = cache.append(&k_token, &v_token)?;
+            assert_eq!(k.dims()[2], 10 + i);
+            assert_eq!(v.dims()[2], 10 + i);
+        }
+
+        assert_eq!(cache.current_seq_len(), 15);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_concat_cache_different_dim() -> Result<()> {
+        let device = Device::Cpu;
+        let mut cache = ConcatKvCache::new(1); // Concatenate on dim 1 instead of 2
+
+        let k1 = Tensor::zeros((1, 3, 8, 64), DType::F32, &device)?;
+        let v1 = Tensor::zeros((1, 3, 8, 64), DType::F32, &device)?;
+        let (k, _v) = cache.append(&k1, &v1)?;
+
+        assert_eq!(k.dims(), &[1, 3, 8, 64]);
+
+        let k2 = Tensor::zeros((1, 2, 8, 64), DType::F32, &device)?;
+        let v2 = Tensor::zeros((1, 2, 8, 64), DType::F32, &device)?;
+        let (k, _v) = cache.append(&k2, &v2)?;
+
+        assert_eq!(k.dims(), &[1, 5, 8, 64]); // Concatenated on dim 1
+        assert_eq!(cache.current_seq_len(), 5);
+
+        Ok(())
+    }
 }
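Beyond the unit tests, the intended call pattern is a prefill append followed by one append per decoded token. A minimal decode-loop sketch (the shapes and the reset-between-sequences convention are illustrative assumptions, not part of this diff):

use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::ConcatKvCache;

fn decode_loop_sketch(device: &Device) -> Result<()> {
    let mut cache = ConcatKvCache::new(2); // [batch, heads, seq, head_dim]

    // Prefill: the whole prompt goes in as a single append.
    let k = Tensor::zeros((1, 8, 10, 64), DType::F32, device)?;
    let v = Tensor::zeros((1, 8, 10, 64), DType::F32, device)?;
    let (_k_all, _v_all) = cache.append(&k, &v)?;

    // Decode: one token per step; each append returns the full K/V so far.
    for step in 0..4 {
        let k_t = Tensor::zeros((1, 8, 1, 64), DType::F32, device)?;
        let v_t = Tensor::zeros((1, 8, 1, 64), DType::F32, device)?;
        let (k_all, _v_all) = cache.append(&k_t, &v_t)?;
        assert_eq!(k_all.dim(2)?, 10 + step + 1);
    }

    // There is no fixed capacity to recycle; the caller starts a new
    // sequence by clearing the cache.
    cache.reset();
    assert!(cache.is_empty());
    Ok(())
}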

candle-transformers/src/models/quantized_qwen3.rs
Lines changed: 4 additions & 14 deletions

@@ -10,7 +10,7 @@ use super::with_tracing::QMatMul;
 use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
 use candle::quantized::{gguf_file, QTensor};
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{kv_cache::KvCache, Activation, Embedding, Module};
+use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
 use std::io::{Read, Seek};
 use std::sync::Arc;
 
@@ -136,7 +136,7 @@ struct AttentionWeights {
     num_kv_groups: usize,
     head_dim: usize,
     rotary_emb: Arc<RotaryEmbedding>,
-    kv_cache: KvCache,
+    kv_cache: ConcatKvCache,
     span_attn: tracing::Span,
 }
 
@@ -160,9 +160,7 @@ impl AttentionWeights {
         let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
         let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;
 
-        // Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
-        // The cache will grow in chunks of 512 tokens when needed.
-        let kv_cache = KvCache::new(2, 512);
+        let kv_cache = ConcatKvCache::new(2);
 
         let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
 
@@ -211,15 +209,7 @@ impl AttentionWeights {
 
         let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
 
-        // Reset KV cache if we're at the first position
-        if offset == 0 {
-            self.kv_cache.reset();
-        }
-        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
-
-        // Make tensor contiguous to avoid some strided copies
-        let k = k.contiguous()?;
-        let v = v.contiguous()?;
+        let (k, v) = self.kv_cache.append(&k, &v)?;
 
         let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
         let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
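Note the `repeat_kv` calls that survive after the append: they expand the grouped KV heads to match the query-head count for grouped-query attention. A minimal sketch of what such a helper does, as an illustrative stand-in rather than the actual body of `crate::utils::repeat_kv`:

use candle::{Result, Tensor};

/// Illustrative GQA head expansion: [B, KV_H, S, D] -> [B, KV_H * n_rep, S, D].
fn repeat_kv_sketch(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(xs);
    }
    let (b, kv_heads, seq, head_dim) = xs.dims4()?;
    // Tile each KV head n_rep times along the sequence axis, then fold the
    // copies back into the head axis so query heads h*n_rep..(h+1)*n_rep
    // all read from original KV head h.
    Tensor::cat(&vec![&xs; n_rep], 2)?.reshape((b, kv_heads * n_rep, seq, head_dim))
}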

candle-transformers/src/models/qwen3.rs
Lines changed: 8 additions & 8 deletions

@@ -3,7 +3,7 @@ use crate::{
     utils::repeat_kv,
 };
 use candle::{DType, Device, Module, Result, Tensor};
-use candle_nn::{kv_cache::KvCache, Activation, VarBuilder};
+use candle_nn::{kv_cache::ConcatKvCache, Activation, VarBuilder};
 use std::sync::Arc;
 
 #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
@@ -108,7 +108,7 @@ pub(crate) struct Qwen3Attention {
     hidden_size: usize,
     // utils
     rotary_emb: Arc<Qwen3RotaryEmbedding>,
-    kv_cache: KvCache,
+    kv_cache: ConcatKvCache,
 }
 
 impl Qwen3Attention {
@@ -157,9 +157,9 @@ impl Qwen3Attention {
         // Necessary because the hidden_size in the config isn't always accurate
        let hidden_size = head_dim * cfg.num_attention_heads;
 
-        // Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
-        // The cache will grow in chunks of 512 tokens when needed.
-        let kv_cache = KvCache::new(2, 512);
+        // dim=2 because we concatenate along the sequence dimension
+        // For tensors of shape [batch, heads, seq, head_dim]
+        let kv_cache = ConcatKvCache::new(2);
 
         Ok(Self {
             q_proj,
@@ -214,11 +214,11 @@ impl Qwen3Attention {
         let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
 
         // 5. Accumulate KV cache
-        let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
+        let (k, v) = self.kv_cache.append(&k, &v)?;
 
         // 6. GQA repeat_kv
-        let k = repeat_kv(k, self.num_kv_groups)?;
-        let v = repeat_kv(v, self.num_kv_groups)?;
+        let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
+        let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
 
         // 7. Attention score
         let scale = 1.0 / (self.head_dim as f64).sqrt();
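For orientation, the `scale` computed on the last line feeds a standard scaled-dot-product attention step. A hedged sketch of that continuation (illustrative only; causal masking is omitted, and this is not the literal remainder of the file):

use candle::{Result, Tensor};

// q, k, v: [B, H, S, D]; k and v are contiguous after the repeat_kv calls above.
fn sdpa_sketch(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
    // [B, H, S_q, D] x [B, H, D, S_kv] -> [B, H, S_q, S_kv]
    let attn = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
    // (an attention mask would be applied here during prefill)
    let attn = candle_nn::ops::softmax_last_dim(&attn)?;
    // [B, H, S_q, S_kv] x [B, H, S_kv, D] -> [B, H, S_q, D]
    attn.matmul(v)
}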
