Muhtasham/char-prefix-conditioning

# Character Prefix Conditioning

A minimal, efficient implementation of character prefix conditioning (CPC) for code completion, inspired by the Cursor blog.

## Overview

When using a language model for code completion, we typically want the model to produce a completion that begins with what the user has typed. However, modern language models operate on sequences of tokens, not characters, so naively tokenizing the user's input and sending it to the model produces wrong results if the user's cursor doesn't happen to lie on a token boundary.

CPC is an algorithm for sampling a sequence of tokens conditioned on a character prefix, ensuring completions always start with the user's typed prefix—even if it doesn't align with token boundaries.

## How It Works

```mermaid
flowchart TD
    A[User types: 'import num'] --> B[Tokenize prompt]
    B --> C[Get model predictions for next token]
    C --> D{Check overlap rule for each token}
    D -->|Valid tokens| E[Renormalize probabilities]
    D -->|No valid tokens| F[Raise ConstraintError]
    E --> G[Sample from valid distribution]
    G --> H{Prefix satisfied?}
    H -->|No| C
    H -->|Yes| I[Continue unconstrained generation]
    I --> J[Return result]
```

### The Overlap Rule

The key insight is the overlap rule for determining valid tokens. A token `r` is valid if:

```
(S + r) is a prefix of P   OR   P is a prefix of (S + r)
```

where:

- `S` = generated string so far
- `r` = candidate token
- `P` = target character prefix
```mermaid
flowchart LR
    subgraph "Case 1: Still building prefix"
        A1["S = 'import '"] --> B1["r = 'nu'"]
        B1 --> C1["S+r = 'import nu'"]
        C1 --> D1["'import nu' is prefix of 'import numpy'"]
        D1 --> E1[VALID]
    end
```
```mermaid
flowchart LR
    subgraph "Case 2: Prefix completed"
        A2["S = 'import num'"] --> B2["r = 'py'"]
        B2 --> C2["S+r = 'import numpy'"]
        C2 --> D2["'import num' is prefix of 'import numpy'"]
        D2 --> E2[VALID - any continuation works]
    end
```
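In code, the overlap rule reduces to a one-line predicate. A minimal sketch covering both cases above (the name `is_valid_token` is illustrative, not the repo's API):

```python
def is_valid_token(generated: str, token: str, prefix: str) -> bool:
    """Overlap rule: the candidate S + r must either still be on the
    way to the target prefix P, or have already covered all of P."""
    candidate = generated + token
    return prefix.startswith(candidate) or candidate.startswith(prefix)

# Case 1: still building the prefix
assert is_valid_token("import ", "nu", "import num")
# Case 2: prefix completed; any continuation works
assert is_valid_token("import num", "py", "import num")
# Invalid: diverges from the prefix
assert not is_valid_token("import ", "os", "import num")
```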

## Architecture

```mermaid
graph TB
    subgraph ModelManager
        M[Model] --> T[Tokenizer]
        T --> TM[Token Map]
        TM --> TR[Prefix Trie]
    end

    subgraph Generation
        P[Prompt] --> CPS[character_prefix_sample]
        CPS --> GAC[_generate_with_autocast]
        GAC --> |Phase 1| CON[Constrained Generation]
        GAC --> |Phase 2| UNC[Unconstrained Generation]
    end

    subgraph "Constraint Checking"
        CON --> TR
        TR --> |O of A_i| VT[Valid Tokens]
        VT --> RN[Renormalize & Sample]
    end
```

## Key Features

### Trie-Based Token Lookup

Instead of O(|V|) iteration over the full vocabulary, we use a trie for O(|A_i|) lookup, where A_i is the set of tokens compatible with the remaining prefix and |A_i| << |V|.

```mermaid
graph TD
    subgraph "Trie Structure"
        R[Root] --> I[i]
        I --> IM[im]
        IM --> IMP[imp]
        IMP --> IMPO[impo]
        IMPO --> IMPOR[impor]
        IMPOR --> IMPORT[import]

        R --> N[n]
        N --> NU[nu]
        NU --> NUM[num]
        NUM --> NUMP[nump]
        NUMP --> NUMPY[numpy]
    end
```
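A minimal sketch of such a trie, assuming the vocabulary is a plain dict of token string to id (illustrative; not the repo's implementation). Walking the remaining prefix collects tokens that are prefixes of it; the subtree at the end collects tokens that extend past it:

```python
class TrieNode:
    """One node per character; token_id is set if a vocab token ends here."""

    def __init__(self):
        self.children = {}   # char -> TrieNode
        self.token_id = None


class TokenTrie:
    def __init__(self, vocab):
        # vocab: token string -> token id
        self.root = TrieNode()
        for token, token_id in vocab.items():
            node = self.root
            for ch in token:
                node = node.children.setdefault(ch, TrieNode())
            node.token_id = token_id

    def valid_tokens(self, remaining):
        """Token ids satisfying the overlap rule against `remaining`
        (the not-yet-generated part of the target prefix): either the
        token is a prefix of `remaining`, or vice versa."""
        valid, node = [], self.root
        for ch in remaining:
            if node.token_id is not None:  # token is a proper prefix of `remaining`
                valid.append(node.token_id)
            node = node.children.get(ch)
            if node is None:               # no vocab token extends this far
                return valid
        # `remaining` fully matched: every token in this subtree covers it
        stack = [node]
        while stack:
            n = stack.pop()
            if n.token_id is not None:
                valid.append(n.token_id)
            stack.extend(n.children.values())
        return valid


vocab = {"n": 0, "nu": 1, "num": 2, "nump": 3, "numpy": 4, "no": 5}
trie = TokenTrie(vocab)
assert sorted(trie.valid_tokens("num")) == [0, 1, 2, 3, 4]  # "no" is excluded
```

Only nodes along the remaining prefix (plus the matching subtree) are visited, which is where the O(|remaining| + |A_i|) bound comes from.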

Benchmark Results (GPT-2, Apple Silicon MPS):

| Method | Time | Speedup |
|--------|------|---------|
| Baseline O(\|V\|) | 1.84s | - |
| Trie O(\|A_i\|) | 0.98s | 46.9% |
| Trie + KV Caching | 0.94s | 48.8% |

### Proper Error Handling

No silent constraint relaxation. If the constraint cannot be satisfied:

```mermaid
flowchart TD
    A[No valid tokens found] --> B{Any valid token exists?}
    B -->|Yes, low prob| C[Use best valid token deterministically]
    B -->|No| D[Raise ConstraintError]
    D --> E[User handles error explicitly]
```

### Two-Phase Generation

```mermaid
sequenceDiagram
    participant U as User
    participant G as Generator
    participant M as Model

    U->>G: prompt="import", prefix="import num"

    rect rgb(200, 230, 200)
        Note over G: Phase 1: Constrained
        loop Until prefix satisfied
            G->>M: Get next token probs
            M-->>G: Probabilities
            G->>G: Filter by overlap rule
            G->>G: Sample from valid tokens
        end
    end

    rect rgb(200, 200, 230)
        Note over G: Phase 2: Unconstrained
        loop Remaining tokens
            G->>M: Get next token probs
            M-->>G: Probabilities
            G->>G: Sample normally
        end
    end

    G-->>U: "import numpy as np..."
```
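The two phases can be sketched as a simplified string-level sampling loop (illustrative only; `next_token_probs` stands in for the model, and the real implementation works on token ids with KV caching):

```python
import random

def two_phase_sample(next_token_probs, prompt, prefix, max_new_tokens):
    """next_token_probs(text) -> {token_string: prob} stands in for the model."""
    out = prompt
    # Phase 1: constrained -- only tokens satisfying the overlap rule
    while not out.startswith(prefix):
        probs = next_token_probs(out)
        valid = {t: p for t, p in probs.items()
                 if prefix.startswith(out + t) or (out + t).startswith(prefix)}
        if not valid:
            raise RuntimeError("prefix constraint unsatisfiable")  # ConstraintError in the repo
        total = sum(valid.values())  # renormalize over the valid tokens
        out += random.choices(list(valid), [p / total for p in valid.values()])[0]
    # Phase 2: unconstrained -- sample normally for the remaining budget
    for _ in range(max_new_tokens):
        probs = next_token_probs(out)
        out += random.choices(list(probs), list(probs.values()))[0]
    return out
```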

## Setup

```bash
uv sync
```

## Usage

```python
from main import ModelManager, character_prefix_sample, ConstraintError

# Initialize and load model
model_manager = ModelManager("gpt2")
model_manager.load_model()

# Generate with character prefix constraint
try:
    result = character_prefix_sample(
        model_manager=model_manager,
        prompt_text="import",
        character_prefix="import num",
        max_new_tokens=50,
        use_trie=True,   # O(|A_i|) lookup (default)
        use_cache=True,  # DynamicCache (default)
    )
    print(result)  # "import numpy as np..."
except ConstraintError as e:
    print(f"Could not satisfy constraint: {e}")
```

### Speculative Decoding (for larger models)

Uses a smaller draft model to propose tokens, verified by the target model in a single forward pass. Implements proper speculative decoding acceptance criterion from Leviathan et al. with residual sampling.

```python
from main import ModelManager, speculative_sample

# Load with draft model for speculative decoding
model_manager = ModelManager("gpt2", draft_model_name="distilgpt2")
model_manager.load_model()

# Generate with speculative decoding (1.5-2x faster on 7B+ models)
result = speculative_sample(
    model_manager=model_manager,
    prompt_text="import",
    character_prefix="import num",
    max_new_tokens=100,
    num_speculative_tokens=10,  # tokens to speculate per iteration
    confidence_threshold=0.4,   # stop drafting if confidence drops
)
```

Benchmark Results (GPT-2 + DistilGPT2, Apple Silicon MPS):

| Test | Standard | Speculative | Speedup |
|------|----------|-------------|---------|
| The quick brown fox... | 1.57s | 1.39s | +11.2% |
| def fibonacci(n):... | 0.97s | 0.94s | +3.5% |
| In the year 2024... | 1.34s | 1.22s | +9.0% |
| Total | 6.01s | 5.76s | +4.1% |

Note: Speculative decoding benefits scale with model size ratio. GPT-2 (124M) + DistilGPT2 (82M) are too similar for large gains. With 7B+ models and tiny drafters (e.g., Llama-3.1-70B + Qwen2-0.5B), expect 1.5-2x speedup.
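The acceptance step from Leviathan et al. can be sketched for a single drafted token (illustrative; the distributions here are plain dicts over a toy vocabulary, not model logits):

```python
import random

def speculative_accept(p_target, p_draft, drafted_token):
    """Leviathan et al.: accept the drafted token with probability
    min(1, p_target(x) / p_draft(x)); on rejection, resample from the
    renormalized residual distribution max(0, p_target - p_draft)."""
    pt = p_target[drafted_token]
    pd = p_draft[drafted_token]
    if random.random() < min(1.0, pt / pd):
        return drafted_token
    # Residual sampling keeps the overall output distribution exactly p_target
    residual = {t: max(0.0, p_target[t] - p_draft.get(t, 0.0)) for t in p_target}
    total = sum(residual.values())
    return random.choices(list(residual), [w / total for w in residual.values()])[0]
```

When draft and target agree the token is always accepted, and a token the target assigns zero probability is always rejected; this criterion is what makes speculative decoding lossless with respect to the target model's distribution.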

## Run Tests & Benchmarks

```bash
# Run unit tests (no model required)
uv run python main.py --unit-test --no-test

# Run integration test suite
uv run python main.py

# Run benchmarks
uv run python main.py --benchmark

# Skip tests, only benchmark
uv run python main.py --no-test --benchmark

# Run speculative decoding benchmark
uv run python main.py --no-test --draft-model distilgpt2 --benchmark-speculative

# Use different model
uv run python main.py --model gpt2-medium
```

## API Reference

### character_prefix_sample()

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `model_manager` | `ModelManager` | required | Loaded model instance |
| `prompt_text` | `str` | required | Initial prompt (can be empty) |
| `character_prefix` | `str` | required | Required prefix constraint |
| `max_new_tokens` | `int` | required | Max tokens to generate |
| `use_trie` | `bool` | `True` | Use trie for O(\|remaining\| + \|A_i\|) lookup |
| `use_cache` | `bool` | `True` | DynamicCache KV caching (efficient with transformers 4.57+) |
| `use_mixed_precision` | `bool` | `False` | fp16 mode (disabled by default; marginal benefit on MPS) |

### batch_generate_samples()

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `model_manager` | `ModelManager` | required | Loaded model instance |
| `prompts` | `List[Tuple]` | required | List of `(prompt, prefix, max_tokens)` tuples |
| `parallel` | `bool` | `False` | Use `ThreadPoolExecutor` for parallel generation |
| `max_workers` | `int` | `4` | Number of parallel workers |
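The batching pattern behind these parameters can be sketched in a self-contained form (hypothetical `batch_generate` helper, not the repo's function):

```python
from concurrent.futures import ThreadPoolExecutor

def batch_generate(sample_fn, prompts, parallel=False, max_workers=4):
    """Apply sample_fn(prompt, prefix, max_tokens) to each
    (prompt, prefix, max_tokens) tuple, optionally across threads."""
    if not parallel:
        return [sample_fn(*args) for args in prompts]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order in its results
        return list(pool.map(lambda args: sample_fn(*args), prompts))
```

Threads (rather than processes) are a reasonable choice here because the heavy work happens inside the model's forward pass, which largely releases the GIL in PyTorch's native kernels.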

### speculative_sample()

Dynamic speculative decoding for faster generation on large models (7B+). Uses a draft model to propose tokens, verified by the target model. Based on Intel Labs / HuggingFace research (Transformers 4.45+).

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `model_manager` | `ModelManager` | required | `ModelManager` with `draft_model` loaded |
| `prompt_text` | `str` | required | Initial prompt (can be empty) |
| `character_prefix` | `str` | required | Required prefix constraint |
| `max_new_tokens` | `int` | required | Max tokens to generate |
| `num_speculative_tokens` | `int` | `20` | Max tokens to speculate per iteration |
| `confidence_threshold` | `float` | `0.4` | Stop drafting if confidence < threshold |
| `use_trie` | `bool` | `True` | Use trie for constraint checking |

### ConstraintError

Raised when the prefix constraint cannot be satisfied. This replaces silent fallback to incorrect results.

## Complexity

```mermaid
graph LR
    subgraph "Per Token"
        A[Model Forward: O of 1] --> B[Trie Lookup: O of A_i]
        B --> C[Sample: O of A_i]
    end

    subgraph "Total for n tokens"
        D[Model Calls: O of n]
        E[Constraint Checks: O of n times A_i]
        F[Memory: O of V for trie]
    end
```
| Operation | Baseline | With Trie | Speculative |
|-----------|----------|-----------|-------------|
| Constraint check | O(\|V\|) | O(\|remaining\| + \|A_i\|) | O(\|remaining\| + \|A_i\|) |
| Model calls per token | 1 | 1 | 1/k (amortized, k = acceptance rate) |
| Memory | O(\|V\|) | O(\|V\|) for trie | O(\|V\|) + draft model |

## Test Cases

| Test | Prompt | Prefix | Result |
|------|--------|--------|--------|
| Simple prefix | `import` | `import num` | `import numpy...` |
| Mid-token | `The model's behav` | `The model's behavi` | `The model's behavior...` |
| F-string | `print(f"The result is {re` | `print(f"The result is {res` | `print(f"The result is {result}...` |
| Empty prompt | `` | `Once upon a ti` | `Once upon a time...` |
| JSON | `{"data": {"user` | `{"data": {"username": "test` | `{"data": {"username": "test"}...` |
| Special chars | `` | `Hello, world!` | `Hello, world! ...` |
