# v0.4.0: Memory-Optimal Model Loading for SFT + OSFT 💾
**Summary:** This release overhauls distributed initialization into a unified prepare → wrap → finalize pipeline that applies to both standard SFT and OSFT. The new flow keeps all ranks other than rank 0 on `meta` tensors until the last moment, rehydrates parameters shard-by-shard, and delivers up to an 80% reduction in peak CPU + GPU memory use when materializing large checkpoints. Alongside the loader refactor, we hardened GPT-OSS conversions, fixed OSFT corner cases, refreshed docs, and sped up our CI/tooling so the new path is well covered.
## Highlights

- ♻️ **Unified three-phase distributed initialization:** `prepare_model_for_fsdp2`, `wrap_fsdp2`, and `finalize_model_initialization` now orchestrate rank-aware loading for both SFT and OSFT, so only rank 0 ever touches the full checkpoint and all other ranks stay lightweight on `meta` tensors.
- 📉 **Up to 80% lower memory footprint:** Materializing each shard directly inside its FSDP2 wrapper eliminates the previous all-gather spikes, keeping both CPU staging buffers and GPU allocations flat even on 70B-class models.
- 🧠 **Safer lazy-init internals:** The new `fsdp2_lazy_init` helpers and `ModelInitializationContext` track initialization state, repair stray tensor attributes, and make `finalize_model_initialization` idempotent for both OSFT factors and vanilla weights.
- 📚 **Contributor polish:** Fresh distributed-initialization documentation, README callouts, and CI/test updates make it easier to reason about the pipeline, while `osft_memory_efficient_init` is deprecated and the memory-efficient path is now fully automatic.
## Performance Improvements

### Unified Three-Phase Distributed Initialization

- **Phase 1 (`prepare_model_for_fsdp2`)** lets rank 0 load the checkpoint on CPU, record the state dict and buffers inside `ModelInitializationContext`, and immediately free the heavyweight module, while every other rank instantiates the model on `meta`.
- **Phase 2 (`wrap_fsdp2`)** now focuses solely on activation-checkpointing each transformer block, building the device mesh, and calling `fully_shard`, so sharding incurs zero data movement.
- **Phase 3 (`finalize_model_initialization`)** streams the stored payload back into the sharded modules—broadcasting vanilla tensors for SFT and driving the distributed SVD path for OSFT—without ever inflating a full-model copy.
## Memory-Safe OSFT Internals

- Introduced `fsdp2_lazy_init` tagging so every model knows whether it is in SFT or OSFT lazy-init mode, preventing accidental re-entry and enabling clean teardown/reinitialization.
- `_sanitize_meta_attribute_aliases` now repairs stray tensors (buffers/parameters that lived on `meta` or CPU) before wrapping, stopping the silent OOM regressions caused by attribute clones that FSDP2 could not shard.
- OSFT models keep their dense payload exclusively on rank 0 until `set_model_state_dict` receives the distributed SVD output, so non-owned tensor references never leak into other ranks.
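The sanitization pass is internal, but the idea can be sketched as a sweep that re-registers any buffer still stranded on `meta` before FSDP2 wrapping. The helper name and scope below are a simplification (the real `_sanitize_meta_attribute_aliases` also covers parameters and CPU-stranded aliases):

```python
import torch
from torch import nn

def sanitize_meta_buffers(module: nn.Module) -> int:
    """Re-register stray meta-device buffers with real storage (sketch)."""
    fixed = 0
    for submod in module.modules():
        for name, buf in list(submod.named_buffers(recurse=False)):
            if buf.is_meta:
                # Replace the un-shardable meta alias with an allocated tensor.
                submod.register_buffer(name, torch.empty_like(buf, device="cpu"))
                fixed += 1
    return fixed

m = nn.Linear(4, 4)
m.register_buffer("stray", torch.empty(4, device="meta"))  # simulated stray alias
assert sanitize_meta_buffers(m) == 1
assert not m.stray.is_meta
```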
## Activation Checkpointing + FSDP2 Integration

- Every transformer block is automatically wrapped with `checkpoint_wrapper` as part of Phase 2, dramatically shrinking activation memory and aligning the FSDP2 graph between SFT and OSFT.
- Device-mesh creation, mixed-precision policies, activation checkpointing, and `fully_shard` now execute in isolation, so the finalize phase can focus on populating weights and clearing buffers.
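The per-block wrapping uses PyTorch's stock checkpoint wrapper. A standalone sketch of what Phase 2 applies to each transformer block (the block construction here is illustrative; Phase 2 walks the real model's layers):

```python
import torch
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)

# One illustrative "transformer block"; the real pipeline wraps each block
# of the loaded model before fully_shard runs.
block = nn.TransformerEncoderLayer(d_model=32, nhead=4, batch_first=True)
block = checkpoint_wrapper(block, checkpoint_impl=CheckpointImpl.NO_REENTRANT)

x = torch.randn(2, 5, 32, requires_grad=True)
loss = block(x).sum()
loss.backward()  # activations are recomputed here instead of being stored
assert x.grad is not None
```

Wrapping before sharding means FSDP2 sees the same module graph for SFT and OSFT, which is what keeps the two paths' sharding layouts aligned.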
## GPT-OSS & Control-Plane Reliability

- GPT-OSS MXFP4 conversion now enforces consistent float32 tie-breaking, bias handling, and dtype casting, so checkpoint loads no longer hit mismatched-tensor errors.
- A dedicated CPU-friendly (Gloo) control process group carries object/broadcast traffic, eliminating NCCL stalls when dispatching OSFT contexts or checkpoint metadata.
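The control-plane change maps onto standard `torch.distributed` primitives: a Gloo group for Python-object traffic alongside the data-plane group. A single-process sketch (address, port, and payload contents are placeholders; real jobs get rank/world size from `torchrun` and keep collectives on NCCL):

```python
import os
import torch.distributed as dist

# Single-process stand-in; a real launch sets these via torchrun.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29531")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Dedicated CPU-friendly group for control traffic (contexts, metadata),
# so object broadcasts never block the GPU data-plane collectives.
ctrl_group = dist.new_group(backend="gloo")

payload = [{"phase": "finalize", "shards": 8}]  # illustrative metadata
dist.broadcast_object_list(payload, src=0, group=ctrl_group)

dist.destroy_process_group()
```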
## Developer Experience & Tooling

- Added `docs/distributed_initialization.md` plus a README callout that diagrams the three phases for contributors and documents both SFT and OSFT entry points.
- The `osft_memory_efficient_init` flag is officially deprecated: distributed jobs always use the lazy-init path, while single-node/non-distributed runs default to eager loading with helpful warnings.
- API wrappers and CLI plumbing emit clear warnings when legacy flags are present and provide actionable guidance for non-distributed scenarios.
## Testing & CI

- Revamped `tests/test_model_initialization.py`, `tests/test_integration_small_models.py`, and related harnesses to cover the new lazy-init state machine, GPT-2/Qwen variants, and both SFT/OSFT permutations.
- Added targeted `tox` environments (`lint-changed`, `format-check-changed`) and taught the GitHub Actions workflows to fetch full history so lint/format checks only touch files that changed, cutting CI runtime.
- Orthogonality utilities now surface debug info immediately, so regression tests fail loudly when distributed SVD inputs misbehave.
## Bug Fixes & Reliability

- Non-distributed OSFT execution once again bypasses the lazy-init path and honors `fsdp2_lazy_init=False`, preventing hangs when `torch.distributed` is absent.
- Fixed dtype mismatches and bias handling while loading GPT-OSS checkpoints, avoiding corrupted experts after MXFP4 quantization.
- Progress bars now advance the epoch counter correctly after each loop, keeping CLI/UI metrics accurate.
- Tightened error handling around distributed saves, synchronized barriers, and cleanup so we do not leak memory when exceptions propagate.
## Upgrade Notes

- Remove `osft_memory_efficient_init` from configs: distributed memory-efficient initialization is automatic, and the flag will be removed in v0.5.0.
- Use `torchrun`/distributed launches whenever you want the lazy-init benefits; non-distributed jobs should explicitly keep `fsdp2_lazy_init=False`.
- Historical `CHANGELOG.md` entries now live in GitHub Releases (this document), so reference release tags for future diffs.
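In practice the upgrade is just dropping the deprecated key from existing configs (the config shape and values below are illustrative):

```python
# Before (v0.3.x): the flag was set explicitly.
config = {
    "model_name_or_path": "my-org/my-70b",   # illustrative value
    "osft_memory_efficient_init": True,      # deprecated in v0.4.0
}

# After (v0.4.0): remove the key. Distributed runs are lazy automatically;
# non-distributed runs should pass fsdp2_lazy_init=False explicitly.
config.pop("osft_memory_efficient_init", None)
assert "osft_memory_efficient_init" not in config
```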
## Contributors

**Full Changelog**: v0.3.1...v0.4.0