HeavyBall 3.0.0
Highlights
- Simplified public API:
Foreach*prefixes removed, short names are now the canonical classes - New optimizers:
HyperBallAdamW,MuonAdamW,LATHER,PSGDPRO LATHER, "Lie-group Adam Through Harmonic Eigenbasis Rotations", performs AdamW in the PSGD eigenbasisRoute-based param dispatch replaces manualSplitOptfor mixed-architecture optimizersScheduleFreeandMSAMmode switches are now idempotent (eval()twice is safe)- Higher-precision PSGD preconditioner updates
- New
consume_gradoption:step()clearsp.gradafter consuming it by default; setconsume_grad=Falseto keep gradients attached after the step orig_shapesis now an explicit documented optimizer argument; usecapture_param_shapes(...)before wrapping models with sharding backends that do not preserve original parameter shapestorch.compile-friendly step with automatic eager fallback for init/preconditioning
Release benchmarks
HeavyBall 3.0.0 was benchmarked against HeavyBall 2.0.0 and torch.optim with
benchmarks/bench_release_optimizers.py, with compiled AdamW step latency
dropping from 10.63 ms in HeavyBall 2.0.0 to 4.15 ms in HeavyBall 3.0.0, a 2.56x speedup.
Breaking changes
Class renames
Every Foreach* class is renamed to its short form. The old short-form aliases (which existed
in 2.x) keep working, only the Foreach* imports break.
| 2.x name | 3.x name |
|---|---|
ForeachAdamW |
AdamW |
ForeachNAdam |
NAdam |
ForeachAdEMAMix |
AdEMAMix |
ForeachAdamC |
AdamC |
ForeachRMSprop |
RMSprop |
ForeachSFAdamW |
SFAdamW |
ForeachADOPT |
ADOPT |
ForeachMuon |
Muon |
ForeachLaProp |
LaProp |
ForeachSignLaProp |
SignLaProp |
ForeachSOAP |
SOAP |
ForeachSOAPNAdam |
SOAPNAdam |
ForeachSOAPAdEMAMix |
SOAPAdEMAMix |
ForeachSOLP |
SOLP |
ForeachPSGDKron |
PSGDKron |
ForeachPSGDLRA |
PSGDLRA |
Removed optimizer classes
These were thin subclasses that only set a class-level default. Use the parent class with the
corresponding constructor argument instead.
| 2.x class | 3.x equivalent |
|---|---|
PaLMForeachSFAdamW / PaLMSFAdamW |
SFAdamW(..., palm=True) |
PaLMForeachSOAP / PaLMSOAP / PalmForEachSoap |
SOAP(..., palm=True) |
PrecondScheduleForeachSOAP / PrecondScheduleSOAP |
SOAP(..., use_precond_schedule=True) |
PrecondSchedulePaLMForeachSOAP / PrecondSchedulePaLMSOAP |
SOAP(..., palm=True, use_precond_schedule=True) |
ForeachPurePSGD / PurePSGD |
PSGDKron(..., exp_avg_input=False) |
ForeachCachedPSGDKron / CachedPSGDKron |
PSGDKron(...) (caching is now the default) |
ForeachDelayedPSGD / DelayedPSGD |
PSGDKron(..., delayed=True) |
ForeachCachedDelayedPSGDKron / CachedDelayedPSGDKron |
PSGDKron(..., delayed=True) |
ForeachCachedNewtonPSGD / NewtonPSGDKron |
PSGDKron(..., hessian_approx=True) |
NewtonHybrid2PSGDKron |
PSGDKron(..., hessian_approx=True, hvp_interval=2) |
ForeachDelayedPSGDLRA / DelayedPSGDLRA |
PSGDLRA(..., delayed=True) |
ForeachNewtonPSGDLRA / NewtonPSGDLRA |
PSGDLRA(..., hessian_approx=True) |
NewtonHybrid2PSGDLRA |
PSGDLRA(..., hessian_approx=True, hvp_interval=2) |
Renamed parameters
| 2.x parameter | 3.x parameter | Notes |
|---|---|---|
foreach |
multi_tensor |
Passing foreach emits a FutureWarning and remaps automatically |
Removed parameters
These raise TypeError if passed. They were either unused or replaced by better defaults.
| Parameter | Previously on | Notes |
|---|---|---|
stochastic_schedule |
SOAP, PSGDKron, PSGDLRA | Deterministic accumulation schedule is now the only mode |
normalize_grads |
SOAP variants | Was unused in the transform pipeline |
correct_bias |
SOAP variants | Was unused in the transform pipeline |
inverse_free |
PSGDKron | Use quad_torch or PSGDPRO for inverse-free PSGD |
adaptive |
PSGDKron | Removed |
Helper sampler kwargs
These compatibility kwargs were removed from heavyball.helpers samplers and now raise
TypeError.
| Class | Removed kwargs |
|---|---|
BoTorchSampler |
constraints_func, consider_running_trials |
HEBOSampler |
constant_liar |
ImplicitNaturalGradientSampler |
lr, warn_independent_sampling |
AutoSampler |
constraints_func |
Chainable API renames
| 2.x name | 3.x name |
|---|---|
Branch |
Parallel |
Behavioral changes
- ScheduleFree / MSAM
eval()/train(): Now idempotent. Callingeval()twice no
longer flips back to train mode. Both methods accept amodeargument matching
nn.Module.train(mode)and returnself. - Gradient lifetime:
consume_grad=Trueis available on all optimizers and clearsp.grad
duringstep()once the gradient has been consumed. Setconsume_grad=Falseif your code
reads gradients after stepping or relies on them remaining attached. - Sharded parameter shapes: Built-in optimizers now expose
orig_shapesexplicitly. Use
capture_param_shapes()before wrapping parameters if your sharding backend hides original
shapes. - PSGD dampening:
dampen_graddefault changed from2**-13to1e-9, and dampening
epsilon usestorch.finfo(float32).epsregardless of input dtype. This improves
preconditioner accuracy but may change convergence behavior.
Checkpoint migration
Use the migration CLI to convert 1.x or 2.x checkpoints:
python scripts/migrate_optimizer_state.py <checkpoint.pt> <OptimizerClass>Old class names (including all aliases listed above) are resolved automatically.
The foreach → multi_tensor key rename in param groups is handled automatically.
Upgrade checklist
- Replace
from heavyball import Foreach*with the short name (e.g.,ForeachAdamW→AdamW) - Replace
foreach=withmulti_tensor=in constructor calls - Replace removed subclass instantiations with parent + kwargs (see table above)
- Remove any
stochastic_schedule,normalize_grads,correct_bias,inverse_free, oradaptivekwargs - Replace
Branch(...)withParallel(...)in custom chainable code - Migrate checkpoints:
python scripts/migrate_optimizer_state.py <ckpt> heavyball.<Optimizer> - If you relied on
eval(); eval()toggling back to train mode, update your code - If your training loop reads
p.gradafterstep(), passconsume_grad=False - Remove obsolete compatibility kwargs from
heavyball.helperssamplers