
v0.4.0: Memory-Optimal Model Loading for SFT + OSFT 💾


@RobotSail RobotSail released this 20 Nov 20:36
· 6 commits to main since this release
35657cb


Summary: This release overhauls distributed initialization into a unified prepare → wrap → finalize pipeline that applies to both standard SFT and OSFT. The new
flow keeps every rank except rank 0 on meta tensors until the last moment, rehydrates parameters shard-by-shard, and drives up to an 80% reduction in peak CPU + GPU
memory use when materializing large checkpoints. Alongside the loader refactor we hardened GPT-OSS conversions, fixed OSFT corner cases, refreshed the docs, and sped
up our CI/tooling so the new path is well covered.

Highlights

  • ♻️ Unified three-phase distributed initialization: prepare_model_for_fsdp2, wrap_fsdp2, and finalize_model_initialization now orchestrate rank-aware
    loading for both SFT and OSFT so only rank 0 ever touches the full checkpoint while all other ranks stay lightweight on meta tensors.
  • 📉 Up to 80% lower memory footprint: Materializing each shard directly inside its FSDP2 wrapper eliminates the previous all-gather spikes, keeping both CPU
    staging buffers and GPU allocations flat even on 70B-class models.
  • 🧠 Safer lazy-init internals: The new fsdp2_lazy_init helpers and ModelInitializationContext track initialization state, repair stray tensor attributes,
    and make finalize_model_initialization idempotent for both OSFT factors and vanilla weights.
  • 📚 Contributor polish: Fresh distributed-initialization documentation, README callouts, and CI/test updates make it easier to reason about the pipeline, and
    the now-deprecated osft_memory_efficient_init flag is a no-op: memory-efficient initialization happens automatically in distributed jobs.

Performance Improvements

Unified Three-Phase Distributed Initialization

  • Phase 1 (prepare_model_for_fsdp2) lets rank 0 load the checkpoint on CPU, record the state dict/buffers inside ModelInitializationContext, and immediately free
    the heavyweight module while every other rank instantiates the model on meta.
  • Phase 2 (wrap_fsdp2) now focuses solely on activation checkpointing each transformer block, building the device mesh, and calling fully_shard, so sharding
    incurs zero data movement.
  • Phase 3 (finalize_model_initialization) streams the stored payload back into the sharded modules—broadcasting vanilla tensors for SFT and driving the distributed
    SVD path for OSFT—without ever inflating a full-model copy.
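The three phases can be pictured with a minimal pure-Python sketch. The function names mirror the release, but the signatures and the ModelInitializationContext fields shown here are illustrative assumptions, not the library's real API:

```python
# Hypothetical sketch of the prepare -> wrap -> finalize flow.
# Field names on ModelInitializationContext are assumptions for illustration.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelInitializationContext:
    """Tracks per-rank state between the prepare and finalize phases."""
    rank: int
    state_dict: Optional[dict] = None  # full weights live only on rank 0
    on_meta: bool = True               # non-zero ranks stay on meta tensors
    finalized: bool = False


def prepare_model_for_fsdp2(rank: int, checkpoint: dict) -> ModelInitializationContext:
    # Phase 1: only rank 0 loads the heavyweight checkpoint on CPU;
    # every other rank keeps meta placeholders and stays lightweight.
    ctx = ModelInitializationContext(rank=rank)
    if rank == 0:
        ctx.state_dict = dict(checkpoint)
        ctx.on_meta = False
    return ctx


def wrap_fsdp2(ctx: ModelInitializationContext) -> ModelInitializationContext:
    # Phase 2: device mesh, activation checkpointing, and fully_shard
    # happen here in the real pipeline; crucially, no weight data moves,
    # so this sketch passes the context through unchanged.
    return ctx


def finalize_model_initialization(ctx: ModelInitializationContext) -> ModelInitializationContext:
    # Phase 3: stream the stored payload into the sharded modules.
    # Idempotent: calling it again is a no-op rather than a re-broadcast.
    if not ctx.finalized:
        ctx.on_meta = False
        ctx.finalized = True
    return ctx
```

The key property the sketch captures is that a full-model copy never exists anywhere except rank 0's CPU staging buffer, and finalize can be retried safely.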

Memory-Safe OSFT Internals

  • Introduced fsdp2_lazy_init tagging so every model knows whether it is in SFT or OSFT lazy-init mode, preventing accidental re-entry and enabling clean teardown/
    reinitialization.
  • _sanitize_meta_attribute_aliases now repairs stray tensors (buffers/parameters that lived on meta or CPU) before wrapping, stopping the silent OOM regressions
    caused by attribute clones that FSDP2 could not shard.
  • OSFT models keep their dense payload exclusively on rank 0 until set_model_state_dict receives the distributed SVD output, so non-owned tensor references never
    leak into other ranks.
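The re-entry guard behaves roughly like the sketch below. The attribute name _fsdp2_lazy_init_mode and both helper names are assumptions chosen for illustration; only the SFT/OSFT mode tagging and clean-teardown behavior come from the release notes:

```python
# Hypothetical sketch of lazy-init mode tagging; attribute and function
# names are illustrative assumptions, not the library's real identifiers.
VALID_MODES = {"sft", "osft"}


def tag_lazy_init(model, mode: str) -> None:
    """Mark a model as being in SFT or OSFT lazy-init mode, refusing re-entry."""
    if mode not in VALID_MODES:
        raise ValueError(f"unknown lazy-init mode: {mode!r}")
    current = getattr(model, "_fsdp2_lazy_init_mode", None)
    if current is not None:
        # Accidental re-entry would re-run expensive prepare logic, so fail loudly.
        raise RuntimeError(f"model is already in {current!r} lazy-init mode")
    model._fsdp2_lazy_init_mode = mode


def clear_lazy_init(model) -> None:
    """Teardown hook so the same model object can be cleanly reinitialized."""
    if hasattr(model, "_fsdp2_lazy_init_mode"):
        del model._fsdp2_lazy_init_mode
```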

Activation Checkpointing + FSDP2 Integration

  • Every transformer block is automatically wrapped with checkpoint_wrapper as part of Phase 2, dramatically shrinking activation memory and aligning the FSDP2
    graph between SFT and OSFT.
  • Device-mesh creation, mixed-precision policies, activation checkpointing, and fully_shard now execute in isolation so the finalize phase can focus on populating
    weights and clearing buffers.
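The Phase 2 ordering can be sketched as a simple loop: checkpoint-wrap each block first, then shard it, so SFT and OSFT end up with identical FSDP2 graphs. The helper below is a generic illustration (in the real pipeline the wrappers are PyTorch's checkpoint_wrapper and fully_shard; here they are passed in as callables so the ordering is the point):

```python
# Sketch of the Phase-2 wrapping order described above. The two callables
# stand in for torch's checkpoint_wrapper and fully_shard; this helper
# name is an assumption for illustration.
def wrap_blocks(blocks, checkpoint_wrapper, fully_shard):
    """Checkpoint-wrap then shard every transformer block, in that order."""
    wrapped = []
    for block in blocks:
        block = checkpoint_wrapper(block)  # recompute activations in backward
        block = fully_shard(block)         # shard parameters across the mesh
        wrapped.append(block)
    return wrapped
```

Keeping this step free of weight movement is what lets the finalize phase focus purely on populating weights and clearing buffers.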

GPT-OSS & Control-Plane Reliability

  • GPT-OSS MXFP4 conversion now enforces consistent float32 tie-breaking, bias handling, and dtype casting so checkpoint loads no longer hit mismatched tensor errors.
  • A dedicated CPU-friendly (Gloo) control process group carries object/broadcast traffic, eliminating NCCL stalls when dispatching OSFT contexts or checkpoint
    metadata.
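The split between the two process groups amounts to routing by traffic type: pickled Python payloads (OSFT contexts, checkpoint metadata) go over the CPU-friendly Gloo group, while tensor collectives stay on NCCL. A minimal sketch of that routing rule, with an illustrative helper name and operation set:

```python
# Sketch of control-plane vs data-plane backend routing. The operation
# names below are torch.distributed collective names; the helper and the
# exact routing table are illustrative assumptions.
CONTROL_OPS = {"broadcast_object_list", "gather_object", "scatter_object_list"}


def backend_for(op: str) -> str:
    """Pick the process-group backend for a collective operation."""
    # Object traffic involves CPU-side pickling, so sending it over Gloo
    # keeps NCCL free of CPU stalls; tensor collectives keep NCCL.
    return "gloo" if op in CONTROL_OPS else "nccl"
```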

Developer Experience & Tooling

  • Added docs/distributed_initialization.md plus a README callout that diagrams the three phases for contributors and documents both SFT and OSFT entry points.
  • The osft_memory_efficient_init flag is officially deprecated—distributed jobs always use the lazy-init path, while single-node/non-distributed runs default to
    eager loading with helpful warnings.
  • API wrappers and CLI plumbing emit clear warnings when legacy flags are present and provide actionable guidance for non-distributed scenarios.

Testing & CI

  • Revamped tests/test_model_initialization.py, tests/test_integration_small_models.py, and related harnesses to cover the new lazy-init state machine, GPT-2/Qwen
    variants, and both SFT/OSFT permutations.
  • Added targeted tox environments (lint-changed, format-check-changed) and taught the GitHub Actions workflows to fetch full history so lint/format checks only
    touch files that changed, cutting CI runtime.
  • Orthogonality utilities now surface debug info immediately so regression tests fail loudly when distributed SVD inputs misbehave.

Bug Fixes & Reliability

  • Non-distributed OSFT execution once again bypasses the lazy-init path and honors fsdp2_lazy_init=False, preventing hangs when torch.distributed is absent.
  • Fixed dtype mismatches and bias handling while loading GPT-OSS checkpoints, avoiding corrupted experts after MXFP4 quantization.
  • Progress bars now advance the epoch counter correctly after each loop, keeping CLI/UI metrics accurate.
  • Tightened error handling around distributed saves, synchronized barriers, and cleanup so we do not leak memory when exceptions propagate.

Upgrade Notes

  • Remove osft_memory_efficient_init from configs—distributed memory-efficient initialization is automatic and the flag will be removed in v0.5.0.
  • Use torchrun/distributed launches whenever you want the lazy-init benefits; non-distributed jobs should explicitly keep fsdp2_lazy_init=False.
  • Historical CHANGELOG.md entries now live in GitHub Releases (this document), so reference release tags for future diffs.
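As a before/after config fragment (field names taken from the flags above; exact YAML layout is an assumption about your config format):

```yaml
# before (v0.3.x) -- deprecated, remove:
osft_memory_efficient_init: true

# after (v0.4.0) -- distributed runs are memory-efficient automatically;
# only single-node / non-distributed jobs need the explicit opt-out:
fsdp2_lazy_init: false
```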

Contributors

Full Changelog: v0.3.1...v0.4.0
