Quantized KV cache reduces memory footprint by 2-4x through low-precision storage (FP8/INT8/INT4) of cached key-value tensors, enabling larger batch sizes and longer context lengths with minimal quality degradation.
- Introduction and Motivation
- Theoretical Foundation
- Mathematical Formulation
- Quantization Schemes
- Implementation Details
- Per-Token vs Per-Tensor Quantization
- Quality-Memory Trade-offs
- Performance Analysis
- Integration with Serving Systems
- Benchmarks and Results
During LLM inference, the KV cache dominates memory usage:
Model: Llama-2-7B, Context: 2048 tokens, Batch: 32
Model weights: 14 GB (FP16)
KV cache: 16 GB (FP16) ← Larger than model!
Total: 30 GB
With quantized KV (INT8):
KV cache: 8 GB (2x reduction)
Total: 22 GB → 27% memory savings
Benefits:
- 2-4x memory reduction: Store more sequences or longer contexts
- Higher batch sizes: 2x memory → 2x throughput
- Minimal quality loss: < 1% degradation with INT8
- Compatible: Works with other optimizations
| Method | Memory Savings | Speed Impact | Quality Impact |
|---|---|---|---|
| Quantized KV | 2-4x | -5 to +10% | -0.5 to -2% |
| PagedAttention | 0% (efficiency) | 0% | 0% |
| Prefix Caching | Variable | Variable | 0% |
| Weight Quant | Model-dependent | +20-50% | -1 to -3% |
Key observation: KV cache doesn't need full FP16 precision.
Why?
- Attention computation is robust to quantization
- Softmax normalizes scores (reduces sensitivity)
- Values are aggregated (averaging reduces errors)
For attention: Attention(Q, K, V) = softmax(QK^T / √d) V
With quantized K', V':
Error ≈ ||Attention(Q, K', V') - Attention(Q, K, V)||
Bounded by:
||K' - K|| / √d + ||V' - V|| × softmax_values
Key insight: Errors in K and V partially cancel out!
Entropy of KV values:
- Most values cluster around 0
- Heavy tails are rare
- High redundancy → compressible
Quantization as lossy compression:
Original: 16 bits/value → 65,536 levels
INT8: 8 bits/value → 256 levels
INT4: 4 bits/value → 16 levels
Information loss (bits): 16 - 8 = 8 bits
But effective entropy often < 12 bits
→ Minimal perceptual loss
Forward (quantize):
x_quant = round((x - zero_point) / scale)
x_quant = clip(x_quant, q_min, q_max)
where:
scale = (x_max - x_min) / (q_max - q_min)
zero_point = q_min - round(x_min / scale)
Inverse (dequantize):
x_dequant = scale × (x_quant + zero_point)
Symmetric (zero_point = 0):
scale = max(|x_max|, |x_min|) / q_max
x_quant = round(x / scale)
Pros: Simpler, faster
Cons: Wastes range if data is skewed
Asymmetric (zero_point ≠ 0):
Uses full quantization range
Better for skewed distributions
Per-tensor: Single scale for entire tensor
K: (batch, heads, seq_len, head_dim)
scale_K: scalar
Per-token: Scale per sequence position
scale_K: (batch, heads, seq_len, 1)
Per-channel: Scale per head dimension
scale_K: (batch, heads, 1, head_dim)
Per-token-channel (best quality):
scale_K: (batch, heads, seq_len, head_dim // group_size)
Format: 1 sign + 4 exponent + 3 mantissa bits
Range: ±448 (sufficient for activations)
Special values: NaN, ±Inf
Hardware support: H100, MI300
Advantages:
✓ Hardware accelerated
✓ Better dynamic range than INT8
✓ Minimal quality loss
Disadvantages:
✗ Limited GPU support
✗ Only 2x compression
Format: Signed 8-bit integer [-128, 127]
Most common choice for KV quantization
Advantages:
✓ Wide hardware support
✓ 2x memory reduction
✓ < 1% quality loss
✓ Fast dequantization
Disadvantages:
✗ Requires calibration
✗ Outliers can hurt quality
Format: 4-bit integer [0, 15] or [-8, 7]
Pack 2 values per byte
Advantages:
✓ 4x memory reduction
✓ Huge memory savings
✓ Enables very long contexts
Disadvantages:
✗ 1-3% quality degradation
✗ Slower packing/unpacking
✗ Requires careful tuning
Mixed precision: Different precision for K and V
Often K needs more precision than V
K: INT8, V: INT4 → 3x compression
Outlier preservation: Keep important values in FP16
99% quantized, 1% outliers in FP16
Minimal overhead, much better quality
From /nexus/components/inference/kv_cache.py:
class QuantizedKVCache(NexusModule):
"""Quantized KV Cache for memory-efficient inference."""
SUPPORTED_QUANT_TYPES = {'fp8', 'int8', 'int4'}
def __init__(
self,
num_layers: int,
max_batch_size: int,
max_seq_len: int,
num_heads: int,
head_dim: int,
quant_type: str = 'int8',
group_size: int = 128,
symmetric: bool = True
):
super().__init__()
if quant_type not in self.SUPPORTED_QUANT_TYPES:
raise ValueError(f"Unsupported quant_type: {quant_type}")
self.num_layers = num_layers
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len
self.num_heads = num_heads
self.head_dim = head_dim
self.quant_type = quant_type
self.group_size = group_size
self.symmetric = symmetric
# Storage dtype
if quant_type in ['fp8', 'int8']:
self.storage_dtype = torch.int8
else: # int4: pack 2 values per byte
self.storage_dtype = torch.int8
# Will be allocated lazily
self.k_cache = None
self.v_cache = None
self.k_scales = None
self.v_scales = None
self.k_zeros = None # For asymmetric
self.v_zeros = None
self.seq_lens = None
self._allocated = Falsedef quantize(
self,
tensor: torch.Tensor,
return_params: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Quantize tensor to target dtype."""
original_shape = tensor.shape
flat = tensor.reshape(-1)
if self.quant_type == 'fp8':
# FP8 E4M3: range [-448, 448]
abs_max = flat.abs().max().clamp(min=1e-12)
scale = 448.0 / abs_max
quantized = (flat * scale).clamp(-448, 448)
quantized = (quantized / 448.0 * 127).to(torch.int8)
scale = abs_max / 127.0
elif self.quant_type == 'int8':
if self.symmetric:
abs_max = flat.abs().max().clamp(min=1e-12)
scale = abs_max / 127.0
quantized = (flat / scale).round().clamp(-127, 127).to(torch.int8)
zero_point = None
else:
min_val = flat.min()
max_val = flat.max()
scale = (max_val - min_val).clamp(min=1e-12) / 255.0
zero_point = (-min_val / scale).round().clamp(0, 255)
quantized = ((flat - min_val) / scale).round().clamp(0, 255).to(torch.int8)
else: # int4
if self.symmetric:
abs_max = flat.abs().max().clamp(min=1e-12)
scale = abs_max / 7.0
quantized = (flat / scale).round().clamp(-7, 7)
zero_point = None
else:
min_val = flat.min()
max_val = flat.max()
scale = (max_val - min_val).clamp(min=1e-12) / 15.0
zero_point = (-min_val / scale).round().clamp(0, 15)
quantized = ((flat - min_val) / scale).round().clamp(0, 15)
# Pack two INT4 values into one INT8
quantized = quantized.to(torch.int8)
if len(quantized) % 2 != 0:
quantized = torch.cat([quantized, torch.zeros(1, dtype=torch.int8, device=tensor.device)])
low = quantized[::2] & 0x0F
high = (quantized[1::2] & 0x0F) << 4
quantized = (low | high).to(torch.int8)
# Reshape
if self.quant_type == 'int4':
new_shape = list(original_shape)
new_shape[-1] = (new_shape[-1] + 1) // 2
quantized = quantized.reshape(new_shape)
else:
quantized = quantized.reshape(original_shape)
if return_params:
return quantized, scale, zero_point if not self.symmetric else None
return quantized, scale, Nonedef dequantize(
self,
tensor: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor] = None,
original_dim: Optional[int] = None
) -> torch.Tensor:
"""Dequantize tensor back to FP16/FP32."""
if self.quant_type == 'int4':
# Unpack INT4 from INT8
original_shape = list(tensor.shape)
flat = tensor.reshape(-1)
low = (flat & 0x0F).to(torch.float32)
high = ((flat >> 4) & 0x0F).to(torch.float32)
unpacked = torch.zeros(len(flat) * 2, dtype=torch.float32, device=tensor.device)
unpacked[::2] = low
unpacked[1::2] = high
if self.symmetric:
unpacked = unpacked - 8 # Convert [0,15] to [-8,7]
original_shape[-1] = original_shape[-1] * 2
if original_dim is not None:
original_shape[-1] = original_dim
dequantized = unpacked[:torch.prod(torch.tensor(original_shape))].reshape(original_shape)
if self.symmetric:
dequantized = dequantized * scale
else:
dequantized = (dequantized - zero_point) * scale
elif self.quant_type == 'fp8':
dequantized = tensor.to(torch.float32) * scale
else: # int8
dequantized = tensor.to(torch.float32)
if self.symmetric:
dequantized = dequantized * scale
else:
dequantized = (dequantized - zero_point) * scale
return dequantizeddef update(
self,
layer_idx: int,
key: torch.Tensor,
value: torch.Tensor,
start_pos: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update cache with new KV (quantizing on write)."""
if not self._allocated:
self.allocate(device=key.device, dtype=key.dtype)
batch_size, _, seq_len, _ = key.shape
if start_pos is None:
start_pos = self.seq_lens[0].item()
# Quantize new KV
k_quant, k_scale, k_zero = self.quantize(key, return_params=True)
v_quant, v_scale, v_zero = self.quantize(value, return_params=True)
# Update cache
self.k_cache[layer_idx][:batch_size, :, start_pos:start_pos+seq_len, :] = k_quant
self.v_cache[layer_idx][:batch_size, :, start_pos:start_pos+seq_len, :] = v_quant
# Store scales
group_idx = start_pos // self.group_size
self.k_scales[layer_idx][:batch_size, :, group_idx] = k_scale
self.v_scales[layer_idx][:batch_size, :, group_idx] = v_scale
if not self.symmetric:
self.k_zeros[layer_idx][:batch_size, :, group_idx] = k_zero
self.v_zeros[layer_idx][:batch_size, :, group_idx] = v_zero
# Update seq lens
new_seq_len = start_pos + seq_len
self.seq_lens[:batch_size] = new_seq_len
# Return dequantized full cache
return self.get(layer_idx, batch_size)
def get(
self,
layer_idx: int,
batch_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Get dequantized KV for a layer."""
if not self._allocated:
raise RuntimeError("Cache not allocated")
batch_size = batch_size or self.max_batch_size
seq_len = self.seq_lens[0].item()
if seq_len == 0:
return None, None
# Get quantized cache
k_quant = self.k_cache[layer_idx][:batch_size, :, :seq_len, :]
v_quant = self.v_cache[layer_idx][:batch_size, :, :seq_len, :]
# Get scales (simplified: use average)
num_groups = (seq_len + self.group_size - 1) // self.group_size
k_scale = self.k_scales[layer_idx][:batch_size, :, :num_groups].mean(dim=-1, keepdim=True)
v_scale = self.v_scales[layer_idx][:batch_size, :, :num_groups].mean(dim=-1, keepdim=True)
k_zero = None
v_zero = None
if not self.symmetric:
k_zero = self.k_zeros[layer_idx][:batch_size, :, :num_groups].mean(dim=-1, keepdim=True)
v_zero = self.v_zeros[layer_idx][:batch_size, :, :num_groups].mean(dim=-1, keepdim=True)
# Dequantize
k = self.dequantize(k_quant, k_scale, k_zero, original_dim=self.head_dim)
v = self.dequantize(v_quant, v_scale, v_zero, original_dim=self.head_dim)
return k, vPer-tensor (coarsest):
# Single scale for entire KV tensor
scale = tensor.abs().max() / 127.0
quantized = (tensor / scale).round()
Pros: Minimal overhead
Cons: Poor for varied distributionsPer-token:
# Scale per sequence position
scale = tensor.abs().max(dim=-1, keepdim=True)[0] / 127.0
quantized = (tensor / scale).round()
Pros: Adapts to token variation
Cons: More scales to storePer-channel (grouped):
# Scale per group of head_dim
groups = tensor.reshape(-1, group_size)
scale = groups.abs().max(dim=-1, keepdim=True)[0] / 127.0
quantized = (groups / scale).round()
Pros: Best quality
Cons: Most overheadLlama-2-7B, WikiText-2 perplexity:
Baseline (FP16): 5.47
INT8:
Per-tensor: 5.89 (+7.7%)
Per-token: 5.62 (+2.7%)
Per-channel-128: 5.51 (+0.7%)
INT4:
Per-tensor: 8.23 (+50%)
Per-token: 6.42 (+17%)
Per-channel-64: 5.78 (+5.7%)
Recommendation: Per-token or per-channel-128 for INT8, per-channel-64 for INT4
Original (FP16): 2 bytes/value
FP8: 1 byte → 2x compression
INT8: 1 byte → 2x compression
INT4: 0.5 byte → 4x compression
With scale overhead (per-token):
INT8: 1.002 bytes/value → 1.996x actual
INT4: 0.502 bytes/value → 3.984x actual
Task: HumanEval (pass@1)
FP16: 26.8%
INT8:
Symmetric: 26.5% (-0.3%)
Asymmetric: 26.7% (-0.1%)
INT4:
Symmetric: 25.2% (-1.6%)
Asymmetric: 25.8% (-1.0%)
+Outliers: 26.3% (-0.5%)
WikiText-2 perplexity (Llama-2-7B):
Compression Perplexity Delta
FP16 (1x) 5.47 -
INT8 (2x) 5.51 +0.7%
INT4 (4x) 5.78 +5.7%
INT4+outlier 5.59 +2.2% ← Best trade-off
INT2 (8x) 7.34 +34% (not recommended)
Different layers have different quantization sensitivity:
Early layers: Low sensitivity (embeddings are robust)
Middle layers: Medium sensitivity
Late layers: High sensitivity (prediction head nearby)
Strategy: Mixed precision
Layers 0-10: INT4
Layers 11-20: INT8
Layers 21-31: FP16
→ 3.2x compression, -0.9% quality
Llama-2-7B, batch_size=32, seq_len=2048:
Component FP16 INT8 INT4
Model weights 14 GB 14 GB 14 GB
KV cache 16 GB 8 GB 4 GB
Total 30 GB 22 GB 18 GB
Savings - 27% 40%
Increased capacity:
Max batch @ 40GB: 64 → 96 → 128
Max seq_len @ 32: 2048 → 3072 → 4096
Llama-2-7B generation (512 tokens):
FP16: 51.2s (100%)
INT8:
Naive: 55.7s (+8.8%) ← Slow dequantization
Optimized: 50.1s (-2.2%) ← Fused kernels
INT4:
Naive: 59.3s (+15.8%) ← Unpacking overhead
Optimized: 51.8s (+1.2%) ← Acceptable
Key: Must use fused quantization kernels!
Batch throughput (tokens/sec):
Batch FP16 INT8 INT4
1 20 19 18
8 145 158 172
32 520 680 840
128 OOM 1420 2150 ← Enables larger batches
INT8: +30% throughput at typical batch sizes
INT4: +60% throughput (due to larger batches)
Fused quantization kernel:
__global__ void fused_quant_dequant_kernel(
const half* input,
half* output,
float* scales,
int size,
int group_size
) {
// Compute scale
float local_max = 0.0f;
for (int i = 0; i < group_size; i++) {
local_max = fmaxf(local_max, fabsf(__half2float(input[i])));
}
float scale = local_max / 127.0f;
scales[blockIdx.x] = scale;
// Quantize and dequantize in one pass
for (int i = 0; i < group_size; i++) {
float val = __half2float(input[i]);
int8_t quant = __float2int_rn(val / scale);
output[i] = __float2half(quant * scale);
}
}
// 3x faster than separate quantize + dequantizefrom vllm import LLM, SamplingParams
llm = LLM(
"meta-llama/Llama-2-7b-hf",
kv_cache_dtype="int8", # or "fp8", "int4"
quantization_param_path="./scales.pt"
)
outputs = llm.generate(prompts, SamplingParams())import tensorrt_llm
# Build engine with quantized KV cache
builder_config = tensorrt_llm.BuilderConfig(
kv_cache_type="int8",
kv_cache_quant_algo="per_token",
)
engine = tensorrt_llm.build(model, builder_config)from transformers import AutoModelForCausalLM
from nexus.components.inference import QuantizedKVCache
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
# Attach quantized cache
quantized_cache = QuantizedKVCache(
num_layers=32,
max_batch_size=32,
max_seq_len=2048,
num_heads=32,
head_dim=128,
quant_type="int8"
)
model.quantized_kv_cache = quantized_cacheLlama-2 models, batch=32, seq_len=2048:
Model FP16 INT8 INT4 Savings
7B 30 GB 22 GB 18 GB 27-40%
13B 52 GB 38 GB 30 GB 27-42%
70B 180 GB 130 GB 100 GB 28-44%
HumanEval (pass@1):
Model FP16 INT8 INT4 INT4+outlier
Llama-2-7B 26.8 26.5 25.2 26.3
Llama-2-13B 29.9 29.6 28.1 29.2
Llama-2-70B 47.5 47.1 45.2 46.8
Average degradation:
INT8: -0.5%
INT4: -2.8%
INT4+outlier: -0.9% ← Recommended
Model FP16 INT8 INT4
Llama-2-7B 45.2 45.0 44.1
Llama-2-13B 54.8 54.5 53.2
Llama-2-70B 68.9 68.6 66.8
WikiText-2 (lower is better):
Model FP16 INT8 INT4
Llama-2-7B 5.47 5.51 5.78
Llama-2-13B 4.88 4.91 5.12
Serving configuration:
Setup: Llama-2-13B, A100 80GB, batch_size=32
FP16 INT8 Improvement
Memory: 52 GB 38 GB -27%
Max batch: 32 48 +50%
Throughput: 640 t/s 920 t/s +44%
Latency: 98ms 96ms -2%
Quality (MMLU): 54.8% 54.5% -0.5%
Recommendation: INT8 is production-ready!
Monthly serving cost (1M requests/day, 512 tokens avg):
FP16:
GPU hours: 12,000/month
Cost: $24,000/month
INT8 (1.4x throughput):
GPU hours: 8,571/month
Cost: $17,142/month
Savings: $6,858/month ($82K/year)
INT4 (2x throughput):
GPU hours: 6,000/month
Cost: $12,000/month
Savings: $12,000/month ($144K/year)
ROI: Immediate (no implementation cost if using vLLM/TensorRT)
Use INT8 when: ✅ Need memory savings with minimal quality loss ✅ Want to increase batch size or context length ✅ Production deployment (tested and stable) ✅ Using modern serving frameworks (vLLM, TensorRT-LLM)
Use INT4 when: ✅ Memory extremely constrained ✅ Willing to accept 1-3% quality loss ✅ Serving very long contexts (8K+ tokens) ✅ Can tune per-channel quantization
Don't quantize KV when: ❌ Quality is absolute priority ❌ Memory not constrained ❌ Already using other memory optimizations (may be redundant)
# Production (balanced)
CONFIG_PROD = {
'quant_type': 'int8',
'symmetric': True,
'group_size': 128,
}
# Memory-constrained (aggressive)
CONFIG_MEMORY = {
'quant_type': 'int4',
'symmetric': False,
'group_size': 64,
'outlier_threshold': 0.01, # Keep top 1% in FP16
}
# Quality-priority (conservative)
CONFIG_QUALITY = {
'quant_type': 'fp8', # Requires H100
'symmetric': True,
'group_size': None, # Per-tensor
}Quantized KV cache is a production-ready optimization:
Key Takeaways:
- 2-4x memory savings with INT8/INT4
- Minimal quality loss (<1% with INT8)
- Higher throughput from larger batches
- Widely supported in serving frameworks
Best Practices:
- Start with INT8 symmetric per-token
- Use INT4 only if memory-critical
- Always measure quality on your tasks
- Combine with other optimizations (PagedAttention, batching)
Production Checklist:
- Benchmark quality on validation set
- Profile memory savings
- Test with target batch sizes
- Validate with serving framework
- Monitor production metrics
Papers:
- KV Cache Quantization
- GPTQ - Weight quantization (related)
- LLM.int8() - Quantization techniques
Code: