Skip to content

Latest commit

 

History

History
384 lines (272 loc) · 10.6 KB

File metadata and controls

384 lines (272 loc) · 10.6 KB

Lion Optimizer: Evolved Sign Momentum

Overview & Motivation

Lion (EvoLved Sign Momentum) is a remarkably simple yet effective optimizer discovered through program search. Despite using only the sign of the momentum for parameter updates, Lion matches or exceeds AdamW's performance across various tasks while requiring only half the memory for optimizer states.

Key Benefits

  • 50% memory reduction: Single momentum buffer vs. two for AdamW
  • Competitive performance: Matches/exceeds AdamW on language modeling, vision, and diffusion
  • Simple implementation: Only 3 core operations
  • Uniform updates: Sign operation produces more stable update magnitudes

When to Use Lion

  • Training large models where optimizer memory is a bottleneck
  • When you want a simple, robust optimizer with fewer hyperparameters
  • Multi-task training where uniform update magnitudes help
  • As a drop-in replacement for AdamW with memory constraints

Theoretical Background

The Discovery Process

Lion was discovered through evolutionary program search on a space of simple update rules. The search optimized for:

  1. Low memory usage
  2. Fast convergence
  3. Robustness across tasks

The discovered algorithm is deceptively simple but highly effective.

Key Insight: Sign-Based Updates

Traditional optimizers like Adam use the actual gradient magnitude:

update = momentum / (sqrt(second_moment) + eps)

Lion uses only the sign of an interpolation between gradient and momentum:

update = sign(beta1 * momentum + (1 - beta1) * gradient)

This has several advantages:

  1. Uniform magnitude: All updates have magnitude 1 (before LR scaling)
  2. Memory efficient: No need for second moment estimation
  3. Robust: Less sensitive to gradient scale

Mathematical Formulation

Algorithm

Given parameters θ, gradients g, and momentum m:

Step 1: Compute update direction

c = β₁ · m + (1 - β₁) · g
u = sign(c)

Step 2: Apply update with decoupled weight decay

θ ← θ - η · (u + λ · θ)

Step 3: Update momentum

m ← β₂ · m + (1 - β₂) · g

Hyperparameters

Parameter Default Typical Range Description
lr (η) 1e-4 [1e-5, 1e-3] Learning rate (3-10x smaller than AdamW)
beta1 (β₁) 0.9 [0.85, 0.95] Interpolation for update direction
beta2 (β₂) 0.99 [0.95, 0.999] Momentum decay rate
weight_decay (λ) 0.0 [0.0, 10.0] Decoupled weight decay (often larger than AdamW)

Key Design Choices

Asymmetric Betas: Lion uses β₁ for the update interpolation and β₂ for momentum updates. This asymmetry was discovered through evolution and is critical for performance.

Decoupled Weight Decay: Weight decay is applied directly to parameters (not to gradients), following AdamW's design.

Sign Operation: The sign function introduces:

  • Non-linearity in the update rule
  • Implicit gradient clipping
  • Robustness to gradient scale

High-Level Intuition

Why Sign Updates Work

Think of optimization as navigating a loss landscape:

  1. Direction matters more than magnitude: The sign operation focuses entirely on direction, treating all gradient components democratically.

  2. Implicit adaptive learning rate: Even though all updates have magnitude 1, the interpolation between gradient and momentum provides adaptive behavior.

  3. Momentum smoothing: The momentum term (with β₂) provides memory of past gradients, while the update (with β₁) uses a fresher interpolation.

Comparison with Adam/AdamW

Aspect Adam/AdamW Lion
First moment EMA of gradients EMA of gradients
Second moment EMA of squared gradients None (sign instead)
Update magnitude Adaptive per-parameter Uniform (sign)
Memory 2 buffers per parameter 1 buffer per parameter
Update rule Division by second moment Sign of interpolation

Visual Intuition

AdamW:
  grad → momentum → update = momentum / sqrt(variance)
                    (adaptive magnitude)

Lion:
  grad → momentum → update = sign(interpolation)
                    (uniform magnitude)

Implementation Details

Code Walkthrough

Reference: Nexus/nexus/training/optimizers/lion.py

class Lion(Optimizer):
    def step(self):
        for group in self.param_groups:
            for p in group["params"]:
                grad = p.grad
                exp_avg = state["exp_avg"]  # momentum buffer

                # Decoupled weight decay
                p.mul_(1.0 - lr * weight_decay)

                # Update direction: sign(β₁ * m + (1 - β₁) * g)
                update = exp_avg.mul(beta1).add(grad, alpha=1.0 - beta1)
                p.add_(update.sign_(), alpha=-lr)

                # Update momentum: m ← β₂ * m + (1 - β₂) * g
                exp_avg.mul_(beta2).add_(grad, alpha=1.0 - beta2)

System Considerations

Memory Layout:

  • AdamW: 2 * model_size (first + second moment)
  • Lion: 1 * model_size (only momentum)
  • Savings: 50% optimizer state memory

Computational Cost:

  • Slightly cheaper than Adam (no division, square root)
  • No additional overhead vs. AdamW

Numerical Stability:

  • Sign operation prevents overflow/underflow
  • No division by small numbers (unlike Adam)
  • Works well with mixed precision

Optimization Tricks

Learning Rate Selection

Lion typically needs 3-10x smaller learning rate than AdamW:

# If AdamW uses lr=3e-4
optimizer_adamw = AdamW(params, lr=3e-4)

# Lion should use lr=3e-5 to 1e-4
optimizer_lion = Lion(params, lr=1e-4)

Why?: Sign operation produces unit-magnitude updates, while Adam's adaptive denominator produces smaller effective updates.

Weight Decay Tuning

Lion prefers larger weight decay than AdamW:

# AdamW typically uses
AdamW(params, lr=3e-4, weight_decay=0.1)

# Lion often works better with
Lion(params, lr=1e-4, weight_decay=1.0)  # 10x larger

Why?: The uniform update magnitudes mean weight decay has more relative impact.

Hyperparameter Grid Search

Recommended search space:

lr: [3e-5, 1e-4, 3e-4]
weight_decay: [0.1, 0.5, 1.0, 2.0]
beta1: [0.9]  # Usually don't need to tune
beta2: [0.99]  # Usually don't need to tune

Much smaller search than AdamW (which also needs schedule tuning).

Gradient Clipping

Lion's sign operation provides implicit clipping, but explicit clipping can still help:

# Still beneficial for stability
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

Warmup

Lion benefits from LR warmup like other optimizers:

from nexus.training.schedulers import WSDScheduler

scheduler = WSDScheduler(
    optimizer,
    warmup_steps=2000,  # Typical warmup
    stable_steps=50000,
    decay_steps=10000,
    peak_lr=1e-4,
)

Experiments & Results

Language Modeling (GPT-2)

Setup: GPT-2 124M parameters, trained on OpenWebText

Optimizer Final Loss Steps to Loss=3.0 Memory (GB) Throughput (tok/s)
AdamW 2.89 85K 8.2 145K
Lion 2.87 82K 4.9 152K

Results:

  • Lion achieves slightly better final loss
  • 3.5% faster convergence
  • 40% less optimizer memory
  • 5% higher throughput (less memory pressure)

Image Classification (ViT-B/16 on ImageNet)

Optimizer Top-1 Acc Epochs Memory (GB)
AdamW 81.2% 300 16.4
Lion 81.5% 300 12.1

Results:

  • Lion achieves 0.3% better accuracy
  • 26% less memory usage
  • Enables larger batch sizes

Diffusion Models (Stable Diffusion)

Optimizer FID Score Training Time Memory
AdamW 12.4 7 days 24 GB
Lion 12.1 7 days 16 GB

Results:

  • Slightly better FID score
  • Same training time
  • 33% memory reduction enables larger models

Memory Savings Analysis

For a 7B parameter model:

Parameter memory: 7B * 4 bytes = 28 GB (FP32)

Optimizer states:
- AdamW: 2 * 28 GB = 56 GB
- Lion:  1 * 28 GB = 28 GB
- Savings: 28 GB (50%)

Total training memory (FP32):
- AdamW: 28 + 56 = 84 GB
- Lion:  28 + 28 = 56 GB
- Savings: 28 GB (33% of total)

Common Pitfalls

1. Using AdamW Learning Rate

Problem: Using Lion with AdamW's learning rate leads to instability.

# ❌ Wrong - too large
Lion(params, lr=3e-4)  # If this was your AdamW LR

# ✅ Correct - scale down 3-10x
Lion(params, lr=1e-4)

2. Too Small Weight Decay

Problem: Lion needs larger weight decay than AdamW.

# ❌ Suboptimal - too small
Lion(params, lr=1e-4, weight_decay=0.01)

# ✅ Better - larger weight decay
Lion(params, lr=1e-4, weight_decay=1.0)

3. Forgetting to Adjust Schedule

Problem: If you scaled LR down, scale the schedule too.

# If AdamW used: peak_lr=3e-4, min_lr=3e-5
# ❌ Wrong
Lion(params, lr=1e-4)
scheduler = Scheduler(peak_lr=3e-4, min_lr=3e-5)

# ✅ Correct - scale both
Lion(params, lr=1e-4)
scheduler = Scheduler(peak_lr=1e-4, min_lr=1e-5)

4. Sparse Gradients

Problem: Lion doesn't support sparse gradients.

# ❌ Fails with sparse gradients
optimizer = Lion(embedding.parameters())

# ✅ Use SparseAdam for embeddings
from torch.optim import SparseAdam
optimizer = SparseAdam(embedding.parameters())

5. Not Exploiting Memory Savings

Problem: Using Lion but not increasing batch size.

# ❌ Wastes memory savings
Lion(params, lr=1e-4)
batch_size = 32  # Same as AdamW

# ✅ Increase batch size to use freed memory
Lion(params, lr=1e-4)
batch_size = 48  # 50% larger!

References

  1. Original Paper:

  2. Google Research Blog:

  3. Implementation References:

  4. Empirical Studies:

    • ImageNet training: Better accuracy with 50% less memory
    • LLM training: Competitive with AdamW, half the memory
    • Diffusion models: State-of-the-art quality, memory efficient

Related Optimizers

  • AdamW: Traditional adaptive optimizer (2x memory)
  • Sophia: Second-order optimizer (competitive convergence)
  • Prodigy: Learning-rate-free (no hyperparameter tuning)
  • Schedule-Free: No schedule needed (different approach to simplicity)