Support Review of ConcatKvCache (#3143) and Plan for Future Adoption #3181

@DrJesseGlass

Summary

PR #3143 introduces ConcatKvCache, which provides a 2-5x GPU speedup for autoregressive generation with no breaking changes to the API. However, it has not yet been reviewed.

This issue aims to:

  1. Raise the visibility of #3143 for maintainer review
  2. Document the systemic need for this improvement across transformers
  3. Propose a path forward once #3143 is merged

Why This Matters Now

Multiple transformer implementations (including the new SmolLM3 in #3180) currently use workarounds that ConcatKvCache would eliminate:

// Current workaround in Llama, Mistral, Phi, Qwen, SmolLM3, etc.:
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;

// With ConcatKvCache - clean and faster:
let (k, v) = self.kv_cache.append(&k, &v)?;

This pattern appears across the codebase, indicating a systemic issue that ConcatKvCache would solve.
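For context, the mechanism behind the speedup is simple: rather than copying each step's keys/values into a pre-allocated buffer (the design today's call sites work around with contiguous()), a concat-based cache grows its tensors with Tensor::cat along the sequence dimension. Below is a minimal sketch of the idea using candle's public Tensor API; the actual implementation in #3143 may differ:

use candle::{Result, Tensor};

// Sketch: accumulate K/V by concatenating along a chosen dimension.
struct SketchCache {
    dim: usize, // sequence axis to concatenate on (2 for [b, heads, seq, head_dim])
    k: Option<Tensor>,
    v: Option<Tensor>,
}

impl SketchCache {
    fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.k = Some(match self.k.take() {
            Some(prev) => Tensor::cat(&[&prev, k], self.dim)?,
            None => k.clone(),
        });
        self.v = Some(match self.v.take() {
            Some(prev) => Tensor::cat(&[&prev, v], self.dim)?,
            None => v.clone(),
        });
        Ok((self.k.clone().unwrap(), self.v.clone().unwrap()))
    }
}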

Background: PR #3143

What It Provides

  • 2-5x GPU speedup for autoregressive generation
  • No breaking changes - API-compatible with existing KvCache
  • Equal or better performance on all backends (GPU, CPU, WASM)
  • Cleaner code - eliminates .contiguous() workarounds

Current Status

  • Awaiting review - not yet merged
  • Implementation complete
  • Benchmarks show clear benefits
  • Needs maintainer attention

The Broader Problem

Current State

Every transformer implementation manually works around KV-cache limitations:

Current pattern:

let kv_cache = KvCache::new(2, 512);
// ...
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;

With ConcatKvCache:

- let kv_cache = KvCache::new(2, 512);
+ let kv_cache = ConcatKvCache::new(2);  // Simpler API - no max_seq_len
  // ...
- let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
+ let (k, v) = self.kv_cache.append(&k, &v)?;  // No contiguous() needed

Additional benefits:

  • Simpler constructor (no max sequence length parameter)
  • Built-in reset() method, useful for generation applications (see the sketch after this list)
  • 2-5x faster on GPU
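
As an illustration of the reset() point, reusing a single cache across prompts might look like this (a sketch only: tokenize, model.forward, sample, prompts, and max_new_tokens are placeholder names, not candle APIs):

// Hypothetical generation loop: one cache instance serves many prompts,
// cleared between runs instead of re-allocated.
let mut cache = ConcatKvCache::new(2); // 2 = sequence axis, per the PR's API
for prompt in prompts {
    cache.reset(); // drop cached K/V from the previous prompt
    let mut tokens = tokenize(prompt)?;
    for _ in 0..max_new_tokens {
        let logits = model.forward(&tokens, &mut cache)?;
        tokens.push(sample(&logits)?);
    }
}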

Models that would benefit:

  • Llama family (llama.rs)
  • Mistral family (mistral.rs)
  • Phi family (phi.rs, phi3.rs)
  • Gemma family (gemma.rs)
  • All other transformers using KV-cache

The recurrence of this pattern suggests that the underlying abstraction, not the individual models, is what needs improving.

Performance Impact

Users running transformers on GPU see autoregressive generation that is 2-5x slower than it could be, simply because the models build on the legacy KvCache implementation.

Proposal: Two-Phase Approach

Phase 1: Get #3143 Reviewed and Merged (Immediate)

  1. Request maintainer review of #3143
  2. Validate benchmarks across different models (a micro-benchmark sketch follows this list)
  3. Test integration with existing transformers
  4. Merge once validated
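
For step 2, a rough way to sanity-check the claimed speedup on a given backend is a micro-benchmark against the existing candle_nn::kv_cache::KvCache (illustrative only, not the PR's benchmark suite; shapes and step count are arbitrary):

use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::KvCache;
use std::time::Instant;

fn main() -> Result<()> {
    let dev = Device::cuda_if_available(0)?;
    let steps = 512;
    // One decode step's worth of K or V: [batch, heads, 1 token, head_dim].
    let tok = |dev: &Device| Tensor::zeros((1, 8, 1, 64), DType::F32, dev);

    // Baseline: existing KvCache with the contiguous() workaround used by
    // the models listed above.
    let mut kv = KvCache::new(2, steps);
    let start = Instant::now();
    for _ in 0..steps {
        let (k, v) = (tok(&dev)?, tok(&dev)?);
        let _ = kv.append(&k.contiguous()?, &v.contiguous()?)?;
    }
    println!("KvCache: {:?}", start.elapsed());

    // Concat-style appends, sketched inline since ConcatKvCache isn't merged.
    let (mut ks, mut vs): (Option<Tensor>, Option<Tensor>) = (None, None);
    let start = Instant::now();
    for _ in 0..steps {
        let (k, v) = (tok(&dev)?, tok(&dev)?);
        ks = Some(match ks {
            Some(p) => Tensor::cat(&[&p, &k], 2)?,
            None => k,
        });
        vs = Some(match vs {
            Some(p) => Tensor::cat(&[&p, &v], 2)?,
            None => v,
        });
    }
    println!("concat:  {:?}", start.elapsed());
    // Caveat: GPU kernels run asynchronously; for precise numbers, force a
    // synchronization (e.g. copy a result to the host) before reading timers.
    Ok(())
}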

Phase 2: Discuss Adoption Strategy (After #3143 Merges)

Once #3143 is merged and proven stable, we should discuss how to adopt it across the codebase:

Questions for Discussion:

  1. Migration strategy - Should we batch-update all models or migrate incrementally?
  2. Deprecation path - How do we handle the API change (constructor signature)?
  3. Timeline - What's a reasonable timeline for migration?
  4. Compatibility - Do we need any compatibility layer (a possible shim is sketched below), or is a clean break acceptable?

Note: ConcatKvCache has a different API:

// Old API (the first argument is the concat dimension)
let kv_cache = KvCache::new(dim, max_seq_len);

// New API (simpler - no max_seq_len parameter)
let kv_cache = ConcatKvCache::new(dim);

This isn't a drop-in replacement, so migration will require updating each transformer's initialization code. However, the benefits seem worth it:

  • 2-5x GPU speedup
  • Simpler API (no max_seq_len)
  • Built-in reset() method (useful for generation)
  • No .contiguous() workarounds needed
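
On the compatibility question above, a thin shim could keep old call sites compiling during an incremental migration. A hypothetical sketch (CompatKvCache and the import path are illustrative, not part of #3143):

use candle::{Result, Tensor};
use candle_nn::kv_cache::ConcatKvCache; // module path assumed from #3143

// Hypothetical shim: keeps the old two-argument constructor while
// delegating to ConcatKvCache, which grows dynamically and has no
// max_seq_len to configure.
pub struct CompatKvCache(ConcatKvCache);

impl CompatKvCache {
    pub fn new(dim: usize, _max_seq_len: usize) -> Self {
        Self(ConcatKvCache::new(dim))
    }
    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
        self.0.append(k, v)
    }
    pub fn reset(&mut self) {
        self.0.reset()
    }
}

This would let models migrate one at a time, with the shim dropped once the last call site is updated.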

I'm happy to help test integration approaches and contribute to the migration effort once we have consensus on the strategy.

Why This Needs Attention

User Impact

  • Users get suboptimal performance (2-5x slower on GPU)
  • Nothing signals to users that generation could be faster
  • The workaround is invisible to users but affects everyone

Developer Impact

  • Code duplication across models
  • Every transformer re-implements the same workaround, violating the DRY principle

Request for Action

Short Term

Can a maintainer review PR #3143?

  • The implementation looks solid
  • Would benefit multiple existing and new PRs

Medium Term (After Merge)

Make ConcatKvCache the default

  • Via a type alias for compatibility (see the sketch below)
  • Update documentation
  • Plan migration for existing models
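
A possible shape for the type-alias step, once call sites have moved to the one-argument constructor (hypothetical; a bare alias cannot paper over the constructor difference on its own):

use candle_nn::kv_cache::ConcatKvCache; // module path assumed from #3143

// Hypothetical: after call sites migrate, the old name points at the new
// implementation, so downstream code that only names the type (struct
// fields, function signatures) keeps compiling unchanged.
pub type KvCache = ConcatKvCache;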

Clean up workarounds

  • Update transformers systematically
  • Remove .contiguous() calls
  • Improve code quality

Long Term

  • ConcatKvCache becomes standard approach
  • New transformers get optimal performance by default
  • Cleaner, more maintainable codebase

The main work is systematic cleanup across transformer implementations, which can be done incrementally.
