
Incorrect Results with FlexDecoding #233

@zaptrem


@BoyuanFeng

Summary

The KVCache.update() method returns the entire cache buffer, including uninitialized (zero) positions, which causes significant numerical errors when used with flex_attention. While this doesn't visibly affect discrete token generation (because of the argmax), it:

  1. Produces incorrect attention values (101% relative error)
  2. Wastes computation on invalid cache positions
  3. Would cause severe issues when generating with real models, especially over longer contexts
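How much of the 2048-slot buffer the kernel actually visits depends on the block mask: flex_attention iterates only over KV blocks that the BlockMask marks as at least partially unmasked and skips fully masked ones. A back-of-the-envelope count, assuming create_block_mask's default 128-token BLOCK_SIZE and the sizes used in the reproduction below:

```python
import math

# Assumption: BLOCK_SIZE = 128 matches create_block_mask's default block size.
BLOCK_SIZE = 128
max_seq_length = 2048    # size of the preallocated KV cache
current_position = 100   # number of filled (valid) cache slots

# Total KV blocks in the full buffer.
total_kv_blocks = math.ceil(max_seq_length / BLOCK_SIZE)
# KV blocks containing at least one valid position; only these are visited.
live_kv_blocks = math.ceil(current_position / BLOCK_SIZE)
# Fraction of KV blocks that are fully masked and therefore skipped.
skipped_fraction = 1 - live_kv_blocks / total_kv_blocks

print(total_kv_blocks, live_kv_blocks, f"{skipped_fraction:.1%}")
```

Under these assumptions only the first KV block is visited for this query. Plain dense masking (for example masked_fill followed by softmax) does not get this saving, since it still computes scores for every slot.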

Reproduction

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16
batch_size, num_heads, head_dim = 1, 32, 128
max_seq_length = 2048
current_position = 100

# Create query and KV cache
q = torch.randn(batch_size, num_heads, 1, head_dim, device=device, dtype=dtype)
k_cache = torch.zeros(batch_size, num_heads, max_seq_length, head_dim, device=device, dtype=dtype)
v_cache = torch.zeros(batch_size, num_heads, max_seq_length, head_dim, device=device, dtype=dtype)

# Fill only valid positions (0-99); positions 100-2047 stay zero
k_cache[:, :, :current_position] = torch.randn(batch_size, num_heads, current_position, head_dim, device=device, dtype=dtype)
v_cache[:, :, :current_position] = torch.randn(batch_size, num_heads, current_position, head_dim, device=device, dtype=dtype)

# Test 1: Current GPT-Fast approach (full cache)
def offset_causal_mask(b, h, q_idx, kv_idx):
    # With Q_LEN=1, q_idx is always 0, so this admits kv_idx 0..current_position-1
    return (q_idx + current_position - 1) >= kv_idx

mask_full = create_block_mask(offset_causal_mask, B=batch_size, H=None, Q_LEN=1, KV_LEN=max_seq_length, device=device)
mask_full.seq_lengths = (1, max_seq_length)  # As done in generate.py
output_full = flex_attention(q, k_cache, v_cache, block_mask=mask_full)

# Test 2: Correct approach (sliced cache)
k_sliced = k_cache[:, :, :current_position]
v_sliced = v_cache[:, :, :current_position]
# Note: with Q_LEN=1, q_idx is always 0, so causal_mask admits only kv_idx == 0 here
mask_sliced = create_block_mask(causal_mask, B=batch_size, H=None, Q_LEN=1, KV_LEN=current_position, device=device)
mask_sliced.seq_lengths = (1, current_position)
output_sliced = flex_attention(q, k_sliced, v_sliced, block_mask=mask_sliced)

# Compare results
error = (output_full - output_sliced).abs()
print(f"Mean error: {error.mean().item():.6f}")
print(f"Relative error: {(error.mean() / output_sliced.abs().mean() * 100).item():.1f}%")
print(f"Full cache std: {output_full.std().item():.6f}")
print(f"Sliced cache std: {output_sliced.std().item():.6f}")
```

Results

```
Mean error: 0.816406
Relative error: 101.0%
Full cache std: 0.152802
Sliced cache std: 1.016770
```

The full cache approach produces completely different results with 101% relative error!
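One way to see which output to trust is to evaluate each mask_mod directly on index tensors, outside flex_attention. For a single decoding query q_idx is always 0, so causal_mask is true only at kv_idx == 0, whereas offset_causal_mask admits positions 0 through 99. If the sliced reference really sees only one position, its output is just that position's value row, which may explain why its std is close to 1 while the full-cache output (a softmax-weighted average over 100 rows) is much smaller. The sketch below checks this against a dense softmax reference; the small head count, head dimension, float32/CPU setting, and the helper name dense_attention are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

current_position = 100
max_seq_length = 2048

# Mask functions as in the reproduction (parameters renamed for clarity).
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def offset_causal_mask(b, h, q_idx, kv_idx):
    return (q_idx + current_position - 1) >= kv_idx

# A single decoding query (Q_LEN=1) always has q_idx == 0.
q_idx = torch.tensor(0)
admitted_full = int(offset_causal_mask(0, 0, q_idx, torch.arange(max_seq_length)).sum())
admitted_sliced = int(causal_mask(0, 0, q_idx, torch.arange(current_position)).sum())
print(f"offset_causal_mask admits {admitted_full} positions")
print(f"causal_mask (Q_LEN=1) admits {admitted_sliced} position(s)")

# Dense reference on small, illustrative sizes.
torch.manual_seed(0)
num_heads, head_dim = 4, 16
q = torch.randn(1, num_heads, 1, head_dim)
k = torch.randn(1, num_heads, current_position, head_dim)
v = torch.randn(1, num_heads, current_position, head_dim)

def dense_attention(q, k, v, allowed):
    """Hypothetical helper: softmax attention with a bool mask over kv positions."""
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~allowed, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

kv_idx = torch.arange(current_position)
out_all_valid = dense_attention(q, k, v, offset_causal_mask(0, 0, q_idx, kv_idx))
out_causal = dense_attention(q, k, v, causal_mask(0, 0, q_idx, kv_idx))

# Attending to all 100 valid positions matches unmasked SDPA on the slice...
print(torch.allclose(out_all_valid, F.scaled_dot_product_attention(q, k, v), atol=1e-5))
# ...while the Q_LEN=1 causal mask simply returns the value row at position 0.
print(torch.allclose(out_causal, v[:, :, :1]))
```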

Slicing the cache fixes the issue, but the K/V shapes then change at every decoding step, which is much slower. It probably also breaks the flash-decoding kernel's assumptions.
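Slicing is not the only way to keep the unfilled slots out of the softmax: keeping the cache at its fixed size and masking every position at or beyond the current length gives the same result while the tensor shapes stay constant from step to step. A minimal plain-PyTorch sketch of that idea, using dense masking rather than flex_attention, a smaller cache than the issue's 2048 slots, and hypothetical helper names decode_step_static and decode_step_sliced:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, num_heads, head_dim = 1, 4, 16
max_seq_length = 256  # fixed cache size (smaller than 2048 to keep the sketch quick)

# Preallocated, zero-initialized cache whose shape never changes.
k_cache = torch.zeros(batch, num_heads, max_seq_length, head_dim)
v_cache = torch.zeros(batch, num_heads, max_seq_length, head_dim)

def decode_step_static(q, k_cache, v_cache, num_valid):
    """Attend over the full fixed-size cache, masking out unfilled slots."""
    kv_idx = torch.arange(k_cache.shape[-2])
    valid = kv_idx < num_valid  # True only for filled positions
    scores = (q @ k_cache.transpose(-2, -1)) / q.shape[-1] ** 0.5
    scores = scores.masked_fill(~valid, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v_cache

def decode_step_sliced(q, k_cache, v_cache, num_valid):
    """Reference: slice to the filled prefix (K/V shape changes every step)."""
    k = k_cache[:, :, :num_valid]
    v = v_cache[:, :, :num_valid]
    return F.scaled_dot_product_attention(q, k, v)

max_abs_diffs = []
for pos in range(1, 6):
    # Write the new token's K/V into slot pos - 1, as a decode step would.
    k_cache[:, :, pos - 1] = torch.randn(batch, num_heads, head_dim)
    v_cache[:, :, pos - 1] = torch.randn(batch, num_heads, head_dim)
    q = torch.randn(batch, num_heads, 1, head_dim)
    out_static = decode_step_static(q, k_cache, v_cache, pos)
    out_sliced = decode_step_sliced(q, k_cache, v_cache, pos)
    max_abs_diffs.append((out_static - out_sliced).abs().max().item())

print(max_abs_diffs)
```

The masked full-cache path and the sliced path should agree up to float32 rounding at every step, because masked slots receive zero softmax weight regardless of the zeros stored in them. The trade-off is that this dense version still scores all max_seq_length slots; a block-sparse BlockMask is what lets flex_attention skip fully masked blocks.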
