
Adding Layerwise GN (https://arxiv.org/pdf/2510.09378) #59

@switiz

Description

Is your feature request related to a problem? Please describe.

Large-batch LLM pretraining currently relies on first-order or approximate second-order optimizers (e.g., AdamW, SOAP, Muon, Shampoo).
However, even the strongest of these still under-utilize curvature information — requiring significantly more iterations to reach the same loss compared to a full Gauss-Newton (GN) preconditioner.

For example, the paper The Potential of Second-Order Optimization for LLMs: A Study with Full Gauss-Newton (arXiv:2510.09378) reports that:

  • GN reaches a loss of 3.25 in 54 steps
  • SOAP requires 292 steps to reach the same loss
    → roughly 5.4× fewer iterations, and GN also extends the critical batch size.

Describe the solution you’d like

Add a Layerwise Gauss-Newton (GN) preconditioning mode, with an optional “Full GN (oracle)” flag for research comparison.

1. Optimizer Core

  • Integrate GN updates via JVP-based preconditioning, avoiding explicit Hessian materialization (a matrix-free sketch follows this list).
  • GN acts as a plug-in preconditioner that wraps existing optimizers (SOAP / Muon / Shampoo).
  • Use inner-loop Muon or AdamW to minimize the quadratic objective under GN preconditioning, with optional line search for stability.
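To make the JVP-based point concrete, here is a minimal matrix-free sketch of the Gauss-Newton-vector product Gv = Jᵀ(∇²_z L)(Jv) using PyTorch's `torch.func` API. The function name `gnvp`, the `loss_fn` signature, and the dict-of-parameters convention (e.g. from `dict(model.named_parameters())`) are illustrative assumptions, not an existing API in this repo.

```python
# A minimal sketch of a matrix-free Gauss-Newton-vector product, assuming a
# PyTorch model made functional with torch.func. All names here (gnvp, loss_fn,
# the dict-of-parameters convention) are illustrative, not part of this repo.
import torch
from torch.func import functional_call, grad, jvp, vjp

def gnvp(model, params, loss_fn, inputs, targets, vec):
    """Compute G @ vec with G = J^T (grad^2_z L) J, never materializing G.

    `params` and `vec` are dicts mapping parameter names to same-shaped tensors.
    """
    # z(theta): model outputs as a function of the parameters only.
    def outputs(p):
        return functional_call(model, p, (inputs,))

    # 1) Jv: push the parameter-space vector through the Jacobian of z(theta).
    z, jv = jvp(outputs, (params,), (vec,))

    # 2) (grad^2_z L)(Jv): Hessian-vector product of the loss w.r.t. the
    #    outputs, computed forward-over-reverse.
    def loss_of_z(zz):
        return loss_fn(zz, targets)

    _, hjv = jvp(grad(loss_of_z), (z,), (jv,))

    # 3) J^T (...): pull the output-space vector back to parameter space.
    _, pullback = vjp(outputs, params)
    (gv,) = pullback(hjv)
    return gv
```

For softmax cross-entropy, ∇²_z L has a cheap closed form (diag(p) − p pᵀ per token), so step 2 could be specialized, but the generic HVP keeps the sketch loss-agnostic.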

2. Layerwise GN Variant (Default)

  • Compute per-layer GN updates, ignoring cross-layer curvature, i.e., a block-diagonal approximation of G (see the sketch after this list).
  • Nearly matches Full GN on medium-scale LLMs and large batches — requiring only ~1.4× more steps than Full GN but ~3.4× fewer steps than SOAP.
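For the layerwise variant, the same operator can be restricted to one parameter block at a time, dropping the cross-layer terms of G. A minimal sketch, assuming the `gnvp` helper above and some matrix-free linear solver (`cg_solve` is a stand-in here; one possible CG sketch appears under "Additional context" below):

```python
# Layerwise (block-diagonal) GN sketch: each named parameter tensor is treated
# as its own block for simplicity, and cross-layer curvature terms are dropped.
# `gnvp_fn` and `cg_solve` are assumed helpers, not existing APIs in this repo.
import torch

def layerwise_gn_directions(model, params, loss_fn, inputs, targets,
                            grads, gnvp_fn, cg_solve):
    """Return per-block directions d_l approximating G_ll^{-1} g_l."""
    directions = {}
    for name, g in grads.items():
        # Matrix-vector product with this block's diagonal part of G only:
        # feed a vector that is zero everywhere except in block `name`.
        def block_mvp(v, name=name):
            full_v = {k: torch.zeros_like(p) for k, p in params.items()}
            full_v[name] = v
            return gnvp_fn(model, params, loss_fn, inputs, targets, full_v)[name]

        # cg_solve stands in for any (damped) matrix-free linear solver.
        directions[name] = cg_solve(block_mvp, g)
    return directions
```

The resulting `directions` dict could then be handed to the wrapped inner optimizer (Muon / AdamW) as the preconditioned update, per the "Optimizer Core" item above.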

3. Why GN?

  • GN uses only the curvature of the loss with respect to the model outputs, so its curvature matrix is
    positive semi-definite by construction; this avoids the negative-curvature instability of full Newton
    updates while still significantly improving iteration efficiency at scale (a one-line check follows).
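As a one-line check of that claim, using the G = Jᵀ ∇²_z L J notation from the "Additional context" section below: whenever the loss is convex in the model outputs (as cross-entropy is in the logits),

$$v^\top G\, v = (Jv)^\top \big(\nabla_z^2 L\big)(Jv) \;\ge\; 0 \quad \text{for every } v,$$

so G is positive semi-definite by construction, whereas the full Hessian used by Newton's method can have negative eigenvalues.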

Describe alternatives you’ve considered

  • Existing SOAP / Muon / Shampoo implementations already provide approximate second-order preconditioning,
    but they lack full curvature fidelity and plateau earlier in large-batch regimes.
  • Extending these optimizers with GN-based preconditioning could preserve backward compatibility
    while improving step efficiency.
  • The GN-prox-linear variant was analyzed but offered little gain, suggesting the loss curvature alone captures most of the benefit.

Additional context

  • Treat Full GN as a research-only configuration (≈ 4–5× slower wall-clock).
  • Layerwise GN is the practical, scalable variant to evaluate for improved step efficiency and batch scaling.
  • Recommended evaluation setup: 45M and 150M-parameter models under large-batch regimes.
  • GN update formula:

    $$\theta_{t+1} = \theta_t - G^{-1} g, \qquad G = J^\top\, \nabla_z^2 L\, J$$

    where J is the Jacobian of the model outputs z with respect to the parameters; products with G are
    computed efficiently via JVPs without explicit Hessian storage (a minimal solver sketch follows this list).
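A minimal damped conjugate-gradient sketch for applying G⁻¹ (in practice (G + λI)⁻¹) to a gradient block using only such matrix-vector products. The iteration count and damping are illustrative defaults, not values from the paper:

```python
# A minimal damped CG sketch for solving (G + damping * I) x = b given only a
# matrix-vector product `mvp` (e.g. the gnvp / block_mvp sketches above).
# Defaults are illustrative, not taken from the paper.
import torch

def cg_solve(mvp, b, iters=10, damping=1e-3, tol=1e-10):
    """Approximately solve (G + damping * I) x = b for one tensor block b."""
    x = torch.zeros_like(b)
    r = b.clone()                      # residual b - A x, with x = 0
    p = r.clone()
    rs_old = torch.dot(r.flatten(), r.flatten())
    for _ in range(iters):
        Ap = mvp(p) + damping * p
        alpha = rs_old / torch.dot(p.flatten(), Ap.flatten())
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = torch.dot(r.flatten(), r.flatten())
        if rs_new < tol:               # residual small enough: stop early
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x
```

In the layerwise sketch above, this would be called once per parameter block with that block's `block_mvp` and gradient.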

Labels: enhancement (New feature or request)