Merged
21 commits
a20d326
add concat cache; use in qwen3
DrJesseGlass Oct 21, 2025
5eeb857
update tradeoff desc; resolve unused var warning in concatKV test
DrJesseGlass Oct 21, 2025
8b9e6b2
update kv-cache concat method description
DrJesseGlass Oct 28, 2025
8205914
quant-qwen leverage concatKV; add 8_0 to example main
DrJesseGlass Oct 28, 2025
6984bdd
format 8_0 load
DrJesseGlass Oct 28, 2025
dc1a4bf
remove trailing ,
DrJesseGlass Oct 28, 2025
571ec7c
trailing line
DrJesseGlass Oct 28, 2025
c0dd10b
Merge branch 'main' into concate_cache/qwen3
DrJesseGlass Nov 11, 2025
b8ef99e
Merge branch 'main' into concate_cache/qwen3
ivarflakstad Nov 11, 2025
c930aa1
removed unnecessary contiguous calls
DrJesseGlass Nov 13, 2025
b17991b
Update candle-nn/src/kv_cache.rs
DrJesseGlass Nov 13, 2025
dabd09d
Update candle-nn/src/kv_cache.rs
DrJesseGlass Nov 13, 2025
73517e5
Update candle-nn/src/kv_cache.rs
DrJesseGlass Nov 13, 2025
a1d39e4
Update candle-nn/src/kv_cache.rs
DrJesseGlass Nov 13, 2025
f55cda1
Update candle-transformers/src/models/quantized_qwen3.rs
DrJesseGlass Nov 13, 2025
d6ffeb6
Update candle-nn/src/kv_cache.rs
DrJesseGlass Nov 13, 2025
e00907a
Update candle-nn/src/kv_cache.rs
DrJesseGlass Nov 13, 2025
14741e6
Update candle-nn/src/kv_cache.rs
DrJesseGlass Nov 13, 2025
36667c6
Update candle-transformers/src/models/quantized_qwen3.rs
DrJesseGlass Nov 13, 2025
c887900
Update candle-transformers/src/models/quantized_qwen3.rs
DrJesseGlass Nov 13, 2025
36ce602
make k and v continguous post repeat in qwen3
DrJesseGlass Nov 13, 2025
6 changes: 6 additions & 0 deletions candle-examples/examples/quantized-qwen3/main.rs
@@ -21,6 +21,8 @@ const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial o
enum Which {
#[value(name = "0.6b")]
W3_0_6b,
#[value(name = "0.6b8_0")]
W3_0_6b8_0,
#[value(name = "1.7b")]
W3_1_7b,
#[value(name = "4b")]
@@ -103,6 +105,7 @@ impl Args {
let api = hf_hub::api::sync::Api::new()?;
let repo = match self.which {
Which::W3_0_6b => "Qwen/Qwen3-0.6B",
Which::W3_0_6b8_0 => "Qwen/Qwen3-0.6B",
Which::W3_1_7b => "Qwen/Qwen3-1.7B",
Which::W3_4b => "Qwen/Qwen3-4B",
Which::W3_8b => "Qwen/Qwen3-8B",
@@ -122,6 +125,9 @@ impl Args {
None => {
let (repo, filename, revision) = match self.which {
Which::W3_0_6b => ("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q4_K_M.gguf", "main"),
Which::W3_0_6b8_0 => {
("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q8_0.gguf", "main")
}
Which::W3_1_7b => ("unsloth/Qwen3-1.7B-GGUF", "Qwen3-1.7B-Q4_K_M.gguf", "main"),
Which::W3_4b => ("unsloth/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf", "main"),
Which::W3_8b => ("unsloth/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf", "main"),
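With this variant in place, the Q8_0 weights can presumably be selected through the example's existing model-selection argument, e.g. `--which 0.6b8_0` (the `--which` flag name is taken from the example's current clap setup, not from the hunks shown here, so treat the exact invocation as an assumption).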
263 changes: 263 additions & 0 deletions candle-nn/src/kv_cache.rs
@@ -631,6 +631,171 @@ impl ScatteredCacheBuilder {
}
}

/// KV-Cache using concatenation for append operations
///
/// This implementation uses `Tensor::cat` instead of `slice_set` for updates,
/// providing significant GPU performance improvements for autoregressive generation.
///
/// # When to Use
///
/// **Recommended for:**
/// - GPU inference (CUDA, Metal)
/// - Autoregressive generation (token-by-token decoding)
///
/// **Use `KvCache` instead for:**
/// - CPU-only inference
/// - When you need fixed memory allocation upfront
///
/// # Example
///
/// ```ignore
/// use candle_nn::kv_cache::ConcatKvCache;
///
/// let mut cache = ConcatKvCache::new(2); // dim=2 for sequence dimension
///
/// // First token (prefill)
/// let k1 = Tensor::randn(0f32, 1., (1, 8, 10, 64), &device)?;
/// let v1 = Tensor::randn(0f32, 1., (1, 8, 10, 64), &device)?;
/// let (k, v) = cache.append(&k1, &v1)?;
///
/// // Subsequent tokens (decode)
/// let k_new = Tensor::randn(0f32, 1., (1, 8, 1, 64), &device)?;
/// let v_new = Tensor::randn(0f32, 1., (1, 8, 1, 64), &device)?;
/// let (k, v) = cache.append(&k_new, &v_new)?;
/// ```
#[derive(Debug, Clone)]
pub struct ConcatKvCache {
k: Option<Tensor>,
v: Option<Tensor>,
dim: usize,
}

impl ConcatKvCache {
/// Create a new empty concatenation-based KV-cache
///
/// # Arguments
/// * `dim` - The dimension along which to concatenate
/// - For attention with shape `[batch, heads, seq, head_dim]`, use `dim=2`
/// - For attention with shape `[batch, seq, heads, head_dim]`, use `dim=1`
///
/// # Example
/// ```ignore
/// // For standard transformer attention: [B, H, S, D]
/// let cache = ConcatKvCache::new(2);
/// ```
pub fn new(dim: usize) -> Self {
Self {
k: None,
v: None,
dim,
}
}

/// Get current sequence length in the cache
///
/// Returns 0 if the cache is empty.
pub fn current_seq_len(&self) -> usize {
self.k
.as_ref()
.and_then(|k| k.dims().get(self.dim).copied())
.unwrap_or(0)
}

/// Check if cache is empty
pub fn is_empty(&self) -> bool {
self.k.is_none()
}

/// Get the concatenation dimension
pub fn dim(&self) -> usize {
self.dim
}

/// Append key and value tensors to the cache
///
/// This is the core operation that uses optimized concatenation kernels.
///
/// # Arguments
/// * `k` - Key tensor to append (shape: [..., seq_len, ...])
/// * `v` - Value tensor to append (shape: [..., seq_len, ...])
///
/// # Returns
/// Tuple of `(full_k, full_v)` containing all cached keys and values,
/// including the newly appended data.
pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
// Update K cache using concatenation
self.k = Some(match &self.k {
None => k.clone(),
Some(k_cache) => {
// Concatenate along the sequence dimension
// GPU kernel for cat is highly optimized:
// - Fused allocation + copy
// - Coalesced memory access
// - Single kernel launch
Tensor::cat(&[k_cache, k], self.dim)?
}
});

// Update V cache using concatenation
self.v = Some(match &self.v {
None => v.clone(),
Some(v_cache) => Tensor::cat(&[v_cache, v], self.dim)?,
});

Ok((
self.k.as_ref().unwrap().clone(),
self.v.as_ref().unwrap().clone(),
))
}

/// Reset the cache (clear all stored keys and values)
///
/// After calling this, `is_empty()` will return `true` and
/// `current_seq_len()` will return 0.
pub fn reset(&mut self) {
self.k = None;
self.v = None;
}

/// Get reference to current K cache data
///
/// Returns `None` if the cache is empty.
pub fn k(&self) -> Option<&Tensor> {
self.k.as_ref()
}

/// Get reference to current V cache data
///
/// Returns `None` if the cache is empty.
pub fn v(&self) -> Option<&Tensor> {
self.v.as_ref()
}

/// Get mutable reference to K cache data
///
/// Returns `None` if the cache is empty.
pub fn k_mut(&mut self) -> Option<&mut Tensor> {
self.k.as_mut()
}

/// Get mutable reference to V cache data
///
/// Returns `None` if the cache is empty.
pub fn v_mut(&mut self) -> Option<&mut Tensor> {
self.v.as_mut()
}

/// Get owned K and V tensors, consuming the cache
///
/// Returns `None` if the cache is empty.
pub fn into_inner(self) -> Option<(Tensor, Tensor)> {
match (self.k, self.v) {
(Some(k), Some(v)) => Some((k, v)),
_ => None,
}
}
}

#[cfg(test)]
mod tests {
use super::*;
@@ -717,4 +882,102 @@ mod tests {

Ok(())
}

#[test]
fn test_concat_cache_basic() -> Result<()> {
let device = Device::Cpu;
let mut cache = ConcatKvCache::new(2);

assert!(cache.is_empty());
assert_eq!(cache.current_seq_len(), 0);

// First append
let k1 = Tensor::zeros((1, 8, 3, 64), DType::F32, &device)?;
let v1 = Tensor::zeros((1, 8, 3, 64), DType::F32, &device)?;
let (k, v) = cache.append(&k1, &v1)?;

assert_eq!(k.dims(), &[1, 8, 3, 64]);
assert_eq!(v.dims(), &[1, 8, 3, 64]);
assert_eq!(cache.current_seq_len(), 3);
assert!(!cache.is_empty());

// Second append
let k2 = Tensor::zeros((1, 8, 2, 64), DType::F32, &device)?;
let v2 = Tensor::zeros((1, 8, 2, 64), DType::F32, &device)?;
let (k, v) = cache.append(&k2, &v2)?;

assert_eq!(k.dims(), &[1, 8, 5, 64]); // 3 + 2
assert_eq!(v.dims(), &[1, 8, 5, 64]);
assert_eq!(cache.current_seq_len(), 5);

Ok(())
}

#[test]
fn test_concat_cache_reset() -> Result<()> {
let device = Device::Cpu;
let mut cache = ConcatKvCache::new(2);

let k = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
let v = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
cache.append(&k, &v)?;

assert_eq!(cache.current_seq_len(), 10);

cache.reset();

assert!(cache.is_empty());
assert_eq!(cache.current_seq_len(), 0);
assert!(cache.k().is_none());
assert!(cache.v().is_none());

Ok(())
}

#[test]
fn test_concat_cache_multiple_appends() -> Result<()> {
let device = Device::Cpu;
let mut cache = ConcatKvCache::new(2);

// Simulate autoregressive generation
let k_prefill = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
let v_prefill = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
cache.append(&k_prefill, &v_prefill)?;

assert_eq!(cache.current_seq_len(), 10);

// Decode phase: append one token at a time
for i in 1..=5 {
let k_token = Tensor::zeros((1, 8, 1, 64), DType::F32, &device)?;
let v_token = Tensor::zeros((1, 8, 1, 64), DType::F32, &device)?;
let (k, v) = cache.append(&k_token, &v_token)?;
assert_eq!(k.dims()[2], 10 + i);
assert_eq!(v.dims()[2], 10 + i);
}

assert_eq!(cache.current_seq_len(), 15);

Ok(())
}

#[test]
fn test_concat_cache_different_dim() -> Result<()> {
let device = Device::Cpu;
let mut cache = ConcatKvCache::new(1); // Concatenate on dim 1 instead of 2

let k1 = Tensor::zeros((1, 3, 8, 64), DType::F32, &device)?;
let v1 = Tensor::zeros((1, 3, 8, 64), DType::F32, &device)?;
let (k, _v) = cache.append(&k1, &v1)?;

assert_eq!(k.dims(), &[1, 3, 8, 64]);

let k2 = Tensor::zeros((1, 2, 8, 64), DType::F32, &device)?;
let v2 = Tensor::zeros((1, 2, 8, 64), DType::F32, &device)?;
let (k, _v) = cache.append(&k2, &v2)?;

assert_eq!(k.dims(), &[1, 5, 8, 64]); // Concatenated on dim 1
assert_eq!(cache.current_seq_len(), 5);

Ok(())
}
}
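To make the intended decode pattern concrete, here is a minimal, self-contained sketch of driving `ConcatKvCache` through one prefill call plus a few decode steps with a toy attention computation. It only assumes the candle tensor API already used in this file (`Tensor::randn`, `matmul`, `transpose`) and `candle_nn::ops::softmax_last_dim`; the surrounding model wiring is hypothetical and not part of the PR.

```rust
use candle::{Device, Result, Tensor};
use candle_nn::kv_cache::ConcatKvCache;

fn toy_decode() -> Result<()> {
    let device = Device::Cpu;
    let (b, h, d) = (1, 8, 64);
    // Shapes are [batch, heads, seq, head_dim], so concatenate on dim 2.
    let mut cache = ConcatKvCache::new(2);

    // Prefill: the whole prompt (10 positions) is appended in one call.
    let k = Tensor::randn(0f32, 1., (b, h, 10, d), &device)?;
    let v = Tensor::randn(0f32, 1., (b, h, 10, d), &device)?;
    cache.append(&k, &v)?;

    // Decode: one new position per step; `append` returns the full cached K/V.
    for _step in 0..4 {
        let q = Tensor::randn(0f32, 1., (b, h, 1, d), &device)?;
        let k_new = Tensor::randn(0f32, 1., (b, h, 1, d), &device)?;
        let v_new = Tensor::randn(0f32, 1., (b, h, 1, d), &device)?;
        let (k_all, v_all) = cache.append(&k_new, &v_new)?;

        // Toy scaled dot-product attention over everything cached so far.
        let scale = 1.0 / (d as f64).sqrt();
        let scores = (q.matmul(&k_all.transpose(2, 3)?)? * scale)?;
        let probs = candle_nn::ops::softmax_last_dim(&scores)?;
        let _out = probs.matmul(&v_all)?; // [b, h, 1, d]
    }

    assert_eq!(cache.current_seq_len(), 14); // 10 prefill + 4 decode steps
    Ok(())
}
```

Unlike the preallocated `KvCache`, there is no capacity to size upfront; memory grows with each concatenation, which is the tradeoff called out in the doc comment above.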
20 changes: 5 additions & 15 deletions candle-transformers/src/models/quantized_qwen3.rs
@@ -10,7 +10,7 @@ use super::with_tracing::QMatMul;
use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
use candle::quantized::{gguf_file, QTensor};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{kv_cache::KvCache, Activation, Embedding, Module};
use candle_nn::{kv_cache::ConcatKvCache, Activation, Embedding, Module};
use std::io::{Read, Seek};
use std::sync::Arc;

@@ -136,7 +136,7 @@ struct AttentionWeights {
num_kv_groups: usize,
head_dim: usize,
rotary_emb: Arc<RotaryEmbedding>,
kv_cache: KvCache,
kv_cache: ConcatKvCache,
span_attn: tracing::Span,
}

@@ -160,9 +160,7 @@ impl AttentionWeights {
let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;

// Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
// The cache will grow in chunks of 512 tokens when needed.
let kv_cache = KvCache::new(2, 512);
let kv_cache = ConcatKvCache::new(2);

let span_attn = tracing::span!(tracing::Level::TRACE, "attn");

@@ -211,18 +209,10 @@ impl AttentionWeights {

let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;

// Reset KV cache if we're at the first position
if offset == 0 {
self.kv_cache.reset();
}
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;

Member:

Ooh interesting. This most recent change actually removes all the speedup we got on metal.

Member:

Non-quantized version lost some t/s as well, but not as bad. I assume the diff is due to how matmul and qmatmul perform on non-contiguous tensors.

Contributor Author:

I'll benchmark with/without .contiguous() on CPU and CUDA (with the knowledge that contiguous improves Metal performance). If different devices need different behavior, I'll add device-specific dispatch inside the KvCache implementation to handle .contiguous() automatically based on device type.

Member:

If Cuda is unaffected I'd prefer, if possible, to keep this simple and instead improve non-contiguous matmul/qmatmul.

We're updating metal matmul very soon anyway.

Contributor Author (DrJesseGlass, Nov 13, 2025):

Most recent tests actually showed a CUDA benefit to contiguous before saving on the quantized version. No difference on CPUs. I tested three code variations:

No Contiguous

let (k, v) = self.kv_cache.append(&k, &v)?;
let k = repeat_kv(k, self.num_kv_groups)?;
let v = repeat_kv(v, self.num_kv_groups)?;

Append Contiguous

let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
let k = repeat_kv(k, self.num_kv_groups)?;
let v = repeat_kv(v, self.num_kv_groups)?;

+Repeat_kv Contig

let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;

On CPU and GPU for Quantized and Full Qwen3-0.6B 8_0

| Example | Features | Model | Sample Length | Configuration | Tokens Generated | Speed (token/s) |
|---|---|---|---|---|---|---|
| qwen | - | 3-0.6b | 100 | no contiguous | 100 | 7.48 |
| qwen | - | 3-0.6b | 100 | append contiguous | 100 | 7.51 |
| qwen | - | 3-0.6b | 100 | +repeat_kv contig | 100 | 7.45 |
| qwen | cuda | 3-0.6b | 1000 | no contiguous | 1000 | 114.00 |
| qwen | cuda | 3-0.6b | 1000 | append contiguous | 1000 | 109.38 |
| qwen | cuda | 3-0.6b | 1000 | +repeat_kv contig | 1000 | 113.76 |
| quantized-qwen3 | - | 0.6b | 100 | no contiguous | 94 | 34.54 |
| quantized-qwen3 | - | 0.6b | 100 | append contiguous | 94 | 34.54 |
| quantized-qwen3 | - | 0.6b | 100 | +repeat_kv contig | 94 | 35.31 |
| quantized-qwen3 | cuda | 0.6b | 1000 | no contiguous | 962 | 62.96 |
| quantized-qwen3 | cuda | 0.6b | 1000 | append contiguous | 962 | 100.43 |
| quantized-qwen3 | cuda | 0.6b | 1000 | +repeat_kv contig | 962 | 100.79 |

This suggests to me that there is no noticeable difference except in concatenating quantized k,v tensors.
Do you think I should look into that as a separate PR?

Member:

nice!

> Do you think I should look into that as a separate PR?

We'll certainly note it for later!
But it looks like your +repeat_kv contig is consistently good across the board, so let's go with that approach for now? :)

Contributor Author:

Ok. I'll do that.


// Make tensor contiguous to avoid some strided copies
let k = k.contiguous()?;
let v = v.contiguous()?;

let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
let k = repeat_kv(k, self.num_kv_groups)?;
let v = repeat_kv(v, self.num_kv_groups)?;

let scale = 1.0 / (self.head_dim as f64).sqrt();
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
12 changes: 6 additions & 6 deletions candle-transformers/src/models/qwen3.rs
@@ -3,7 +3,7 @@ use crate::{
utils::repeat_kv,
};
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{kv_cache::KvCache, Activation, VarBuilder};
use candle_nn::{kv_cache::ConcatKvCache, Activation, VarBuilder};
use std::sync::Arc;

#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
@@ -108,7 +108,7 @@ pub(crate) struct Qwen3Attention {
hidden_size: usize,
// utils
rotary_emb: Arc<Qwen3RotaryEmbedding>,
kv_cache: KvCache,
kv_cache: ConcatKvCache,
}

impl Qwen3Attention {
@@ -157,9 +157,9 @@ impl Qwen3Attention {
// Necessary because the hidden_size in the config isn't always accurate
let hidden_size = head_dim * cfg.num_attention_heads;

// Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
// The cache will grow in chunks of 512 tokens when needed.
let kv_cache = KvCache::new(2, 512);
// dim=2 because we concatenate along the sequence dimension
// For tensors of shape [batch, heads, seq, head_dim]
let kv_cache = ConcatKvCache::new(2);

Ok(Self {
q_proj,
@@ -214,7 +214,7 @@ impl Qwen3Attention {
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;

// 5. Accumulate KV cache
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
let (k, v) = self.kv_cache.append(&k, &v)?;

// 6. GQA repeat_kv
let k = repeat_kv(k, self.num_kv_groups)?;