Is your feature request related to a problem? Please describe.
Large-batch LLM pretraining currently relies on first-order or approximate second-order optimizers (e.g., AdamW, SOAP, Muon, Shampoo).
However, even the strongest of these still under-utilize curvature information, requiring significantly more iterations than a full Gauss-Newton (GN) preconditioner to reach the same loss.
For example, the paper The Potential of Second-Order Optimization for LLMs: A Study with Full Gauss-Newton (arXiv:2510.09378) reports that:
- GN reaches a loss of 3.25 in 54 steps
- SOAP requires 292 steps to reach the same loss
→ roughly 5.4× fewer iterations, and GN also extends the critical batch size.
Describe the solution you’d like
Add a Layerwise Gauss-Newton (GN) preconditioning mode, with an optional “Full GN (oracle)” flag for research comparison.
1. Optimizer Core
- Integrate GN updates via JVP-based preconditioning, avoiding explicit Hessian materialization.
- GN acts as a plug-in preconditioner that wraps existing optimizers (SOAP / Muon / Shampoo).
- Use inner-loop Muon or AdamW to minimize the quadratic objective under GN preconditioning, with optional line search for stability.
2. Layerwise GN Variant (Default)
- Compute per-layer GN updates, ignoring cross-layer curvature (see the JAX sketch after this list).
- Nearly matches Full GN on medium-scale LLMs and large batches — requiring only ~1.4× more steps than Full GN but ~3.4× fewer steps than SOAP.
3. Why GN?
- GN uses only the curvature of the loss with respect to the model outputs (a positive semi-definite approximation),
avoiding the negative-curvature instability of full Newton updates while significantly improving iteration efficiency at scale.
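A minimal JAX sketch of how the layerwise GN preconditioner could be wired up (hypothetical names throughout: `model_fn(params, batch) -> logits`, a scalar `loss_fn(logits, batch)`; the damping and CG-iteration settings are placeholders, not the paper's recipe):

```python
# Sketch of a layerwise Gauss-Newton preconditioner (matrix-free, via JVP/VJP).
# G = J^T (∇_z^2 L) J is never materialized: each G·v product is one forward
# JVP through the model, one Hessian-vector product in output space, and one VJP.

import jax
import jax.numpy as jnp


def gn_matvec(params, batch, v, model_fn, loss_fn):
    """Full Gauss-Newton--vector product G @ v over the whole parameter pytree."""
    f = lambda p: model_fn(p, batch)             # params -> outputs z
    z, Jv = jax.jvp(f, (params,), (v,))          # u = J v         (forward JVP)
    grad_z = lambda zz: jax.grad(loss_fn)(zz, batch)
    _, HJv = jax.jvp(grad_z, (z,), (Jv,))        # w = (∇_z^2 L) u (output-space HVP)
    _, vjp_fn = jax.vjp(f, params)
    return vjp_fn(HJv)[0]                        # G v = J^T w     (reverse VJP)


def layerwise_gn_direction(params, batch, grads, model_fn, loss_fn,
                           damping=1e-3, cg_iters=10):
    """Solve (G_ll + damping * I) d_l = g_l independently for each layer l,
    ignoring cross-layer curvature blocks."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    grad_leaves = jax.tree_util.tree_leaves(grads)
    zeros = [jnp.zeros_like(x) for x in leaves]
    directions = []
    for i, g_l in enumerate(grad_leaves):
        def block_matvec(v_l, i=i):
            # Embed the layer-l vector into a full pytree (zeros elsewhere),
            # take a full GN-vector product, and read back the layer-l block.
            v_leaves = list(zeros)
            v_leaves[i] = v_l
            v_tree = jax.tree_util.tree_unflatten(treedef, v_leaves)
            Gv = gn_matvec(params, batch, v_tree, model_fn, loss_fn)
            return jax.tree_util.tree_leaves(Gv)[i] + damping * v_l
        d_l, _ = jax.scipy.sparse.linalg.cg(block_matvec, g_l, maxiter=cg_iters)
        directions.append(d_l)
    return jax.tree_util.tree_unflatten(treedef, directions)
```

The returned directions pytree would then replace the raw gradient fed to the existing SOAP / Muon / Shampoo update (or to the inner-loop Muon/AdamW solve with optional line search), so GN acts purely as a plug-in preconditioner.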
Describe alternatives you’ve considered
- Existing SOAP / Muon / Shampoo implementations already provide approximate second-order preconditioning, but they lack full curvature fidelity and plateau earlier in large-batch regimes.
- Extending these optimizers with GN-based preconditioning could preserve backward compatibility while improving step efficiency.
- The GN-prox-linear variant was analyzed but offered little gain, suggesting the loss curvature alone captures most of the benefit.
Additional context
- Treat Full GN as a research-only configuration (≈ 4–5× slower wall-clock).
- Layerwise GN is the practical, scalable variant to evaluate for improved step efficiency and batch scaling.
- Recommended evaluation setup: 45M- and 150M-parameter models under large-batch regimes.
- GN update formula (with $J$ the Jacobian of the model outputs $z$ with respect to $\theta$, $\nabla_z^2 L$ the output-space Hessian of the loss, and $g$ the gradient):

  $$\theta_{t+1} = \theta_t - G^{-1} g, \qquad G = J^\top \nabla_z^2 L \, J$$

  implemented efficiently via JVPs without explicit Hessian storage.
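For reference, a matrix-free reading of that formula (a standard Hessian-free sketch, not necessarily the paper's exact procedure) evaluates $Gv$ in three steps:

$$G v = J^\top \big(\nabla_z^2 L\big)(J v): \qquad u = J v \;\;(\text{forward JVP}), \qquad w = (\nabla_z^2 L)\,u \;\;(\text{output-space HVP}), \qquad G v = J^\top w \;\;(\text{reverse VJP}).$$

An iterative solver (CG, or the inner-loop Muon/AdamW mentioned above) then only ever needs $Gv$ products, never $G$ itself.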