A minimal implementation of Fully Sharded Data Parallel (FSDP) in the style of PyTorch FSDP2.
At a high level, parameters are unsharded (all-gathered) on demand for computation and resharded afterwards, during both the forward and backward passes.
- Drop-in `fully_shard()` wrapper for PyTorch modules (see the usage sketch below)
- Automatic sharding/unsharding during training
- Minimal reference implementation (~1 file, <300 LOC)
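A minimal usage sketch of how the wrapper might be applied, intended to run under `torchrun` as in the quick-start commands below. The import path `src.fsdp`, the exact `fully_shard()` signature, and the in-place sharding behavior are assumptions rather than the repo's confirmed API:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

from src.fsdp import fully_shard  # assumed import path

# One process per GPU, launched via torchrun.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
fully_shard(model)  # assumed to shard parameters in place across all ranks

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
for _ in range(10):
    # Standard training step; the wrapper is expected to unshard parameters
    # for the forward/backward computation and reshard them afterwards.
    loss = model(torch.randn(8, 1024, device="cuda")).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

dist.destroy_process_group()
```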
```bash
# Install dependencies using uv
uv sync
source .venv/bin/activate

# Run training example (2 GPUs)
torchrun --nproc_per_node=2 train.py
```
```bash
# Install dependencies using uv
uv sync

# Run all tests
uv run pytest -v
```
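For orientation, a hypothetical parity-style test is sketched below. It is not taken from `tests/test_fsdp.py`; the single-rank `gloo` setup and the in-place `fully_shard()` behavior are assumptions:

```python
import copy

import torch
import torch.distributed as dist
import torch.nn as nn

from src.fsdp import fully_shard  # assumed import path


def test_forward_parity_single_rank():
    # Single-process "gloo" group so the sketch runs without GPUs or torchrun.
    if not dist.is_initialized():
        dist.init_process_group(
            "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
        )

    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))
    reference = copy.deepcopy(model)
    fully_shard(model)  # assumed to shard parameters in place

    # With a single rank, the sharded model should match the unsharded copy.
    x = torch.randn(2, 16)
    torch.testing.assert_close(model(x), reference(x))
```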
```
.
├── src/
│   └── fsdp.py          # Core FSDP implementation
├── tests/
│   └── test_fsdp.py     # Unit tests for FSDP
├── train.py             # Example training script with LLaMA
└── README.md
```