Feature Request: FlashOptim - optimizer memory reduction #4171
Description
Summary
Request to investigate integrating FlashOptim into Megatron Core's optimizer infrastructure. FlashOptim reduces AdamW memory from 16 to 7 bytes/parameter (5 with gradient release) by combining ULP-bounded master weight splitting with companded 8-bit optimizer state quantization — with no measurable quality loss and no wall-clock overhead.
Motivation
Optimizer states dominate training memory for large models. FlashOptim addresses this with two complementary techniques:
Master Weight Splitting: Splits FP32 weights into BF16 + INT8 error correction. Key insight: reconstruction error is bounded by the unit of least precision (ULP), so the exponent bits in the correction term are redundant. Achieves ~24-bit effective precision with bitwise-perfect reconstruction in 99.92% of values.
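The splitting idea can be sketched in a few lines. This is a hypothetical illustration, not FlashOptim's actual implementation: keep the top 16 bits of each FP32 value (the BF16 payload) plus the next 8 mantissa bits as an INT8 correction. Because the residual is smaller than one BF16 ULP, its exponent is implied by the BF16 value, so the correction needs no exponent bits.

```python
import numpy as np

def split(w_fp32):
    """Split FP32 weights into a BF16 part plus an 8-bit correction.

    Hypothetical sketch: the top 16 bits are the BF16-representable value;
    the next 8 mantissa bits form the correction term. The dropped low 8
    mantissa bits are the only loss, giving ~24-bit effective precision.
    """
    bits = w_fp32.view(np.uint32)
    hi = (bits >> 16).astype(np.uint16)           # BF16 payload (2 bytes)
    corr = ((bits >> 8) & 0xFF).astype(np.uint8)  # mantissa correction (1 byte)
    return hi, corr

def reconstruct(hi, corr):
    """Reassemble an FP32 value from BF16 payload + INT8 correction."""
    bits = (hi.astype(np.uint32) << 16) | (corr.astype(np.uint32) << 8)
    return bits.view(np.float32)
```

With this layout the master weight costs 3 bytes/parameter instead of 4, and the reconstruction error per element is below one part in 2^15 of the value's magnitude.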
Companded State Quantization: Applies nonlinear companding before INT8 quantization of optimizer states — 2x/(1+|x|) for momentum (signed), √x for variance (unsigned). This is critical: linear quantization of Adam states causes divergence due to heavy-tailed variance distributions.
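A minimal sketch of the companded round trip, under the assumption (not stated in the issue) that states are first normalized by their absolute maximum before companding and INT8 rounding; the function names are illustrative, not FlashOptim's API:

```python
import numpy as np

def quantize_momentum(m):
    """Signed companding f(x) = 2x / (1 + |x|) on [-1, 1], then INT8."""
    scale = float(np.abs(m).max()) or 1.0
    x = m / scale                       # normalize to [-1, 1]
    y = 2.0 * x / (1.0 + np.abs(x))    # compand: more levels near zero
    return np.round(y * 127.0).astype(np.int8), scale

def dequantize_momentum(q, scale):
    y = q.astype(np.float32) / 127.0
    x = y / (2.0 - np.abs(y))          # exact inverse of the compander
    return x * scale

def quantize_variance(v):
    """Unsigned sqrt companding on [0, 1], then 8-bit unsigned."""
    scale = float(v.max()) or 1.0
    return np.round(np.sqrt(v / scale) * 255.0).astype(np.uint8), scale

def dequantize_variance(q, scale):
    y = q.astype(np.float32) / 255.0
    return (y * y) * scale
```

The companders concentrate quantization levels where Adam states actually live: near zero for momentum, and in the low range of the heavy-tailed variance, which is why linear INT8 diverges but companded INT8 does not.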
Results:
| Task | Model | Reference | FlashOptim |
|---|---|---|---|
| ImageNet | ResNet-50 | 77.01% | 77.16% |
| Pretraining | GPT-2 124M | 3.263 loss | 3.265 loss |
| Finetuning | Llama-3.1-8B | 75.09% GSM8k | 74.98% GSM8k |
All within measurement variance. Optimizer step is actually 8% faster on Llama-3.1-8B (fused Triton kernels). Supports SGD, AdamW, and Lion.
Memory Savings (bytes/parameter)
| | AdamW | FlashAdamW | FlashAdamW + grad release |
|---|---|---|---|
| Total | 16 | 7 | 5 |
Achieved via BF16+INT8 master weight splitting and companded INT8 optimizer states.
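The totals above are consistent with the following per-parameter accounting, assuming a common mixed-precision baseline of BF16 weights with FP32 master weights, FP32 states, and a persistent BF16 gradient buffer (the issue does not spell out this breakdown):

```python
# Hedged per-parameter byte accounting; the baseline layout is an assumption.
adamw = 2 + 4 + 4 + 4 + 2   # bf16 weight + fp32 master + fp32 m + fp32 v + bf16 grad
flash = 2 + 1 + 1 + 1 + 2   # bf16 weight + int8 correction + int8 m + int8 v + bf16 grad
flash_grad_release = flash - 2  # gradient release frees the persistent grad buffer

print(adamw, flash, flash_grad_release)
```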
Requested Feature
Investigate adding FlashAdamW as a supported optimizer in Megatron-Core's distributed optimizer, including compatibility with existing distributed checkpointing. FlashOptim >= 0.1.3 provides native DTensor support for PyTorch DCP/FSDP2 integration.
References
- FlashOptim Paper (Gonzalez Ortiz, Gupta, Rinard, Blalock)
- GitHub: databricks/flashoptim
- Automodel FlashOptim integration - merged PR adding FlashAdamW with FSDP2/DCP support