v0.2.0a1: GPT-OSS Support & Memory Optimizations [Alpha 1]
Pre-release
## New Features
- GPT-OSS Model Support: Full support for OpenAI's new open-weight GPT-OSS models (20B and 120B variants) with native MXFP4 quantization implementation
- Memory-Efficient OSFT Initialization: New `osft_memory_efficient_init` flag for optimized initialization of large models
- Training Dtype Control: New `train_dtype` parameter for switching models to bf16/fp16 training to reduce memory usage (use sparingly, as lower precision may impact results)
- Pretraining Data Conversion: New `convert_to_pretrain.py` script for converting conversation datasets to pretraining format
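For intuition on the MXFP4 format mentioned above: it packs weights as 4-bit E2M1 floats that share one power-of-two scale per block (per the OCP Microscaling spec). The toy decoder below is an illustration of the format only, not the implementation in this repo's `gpt_oss_utils.py`:

```python
import numpy as np

# The 8 non-negative E2M1 magnitudes (2 exponent bits, 1 mantissa bit, bias 1):
# subnormals 0.0 and 0.5, then 1.0, 1.5, 2.0, 3.0, 4.0, 6.0.
E2M1_MAGNITUDES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def decode_mxfp4_block(nibbles, shared_exponent):
    """Decode one MXFP4 block: 4-bit codes plus a shared E8M0 scale (2**(e - 127))."""
    signs = np.where(nibbles & 0x8, -1.0, 1.0)      # top bit is the sign
    magnitudes = E2M1_MAGNITUDES[nibbles & 0x7]     # low 3 bits index the magnitude
    return signs * magnitudes * 2.0 ** (shared_exponent - 127)

# Example: codes 0b0001 (+0.5) and 0b1111 (-6.0) with shared scale 2**1
block = decode_mxfp4_block(np.array([0b0001, 0b1111]), shared_exponent=128)
print(block.tolist())  # [1.0, -12.0]
```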
## API Enhancements
- OSFT Dtype Controls: New `osft_upcast_dtype` for computation precision (default: float32) and `osft_output_dtype` for output precision control
- Enhanced Data Processing: Improved `process_data.py` with additional functionality
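As a sketch, the two new OSFT dtype controls would be passed through training arguments like any other field. Only the parameter names come from this release; the specific string values below are assumptions:

```python
# Hypothetical values: the parameter names (osft_upcast_dtype, osft_output_dtype)
# come from this release, but the accepted value strings are assumed.
osft_dtype_kwargs = {
    "osft_upcast_dtype": "float32",   # computation precision (the release default)
    "osft_output_dtype": "bfloat16",  # output precision (assumed choice)
}
# These would be passed alongside other TrainingArgs fields, e.g.:
# train_args = TrainingArgs(model_name="openai/gpt-oss-20b", **osft_dtype_kwargs)
print(osft_dtype_kwargs["osft_upcast_dtype"])  # float32
```

Keeping the upcast dtype at float32 preserves the default computation precision while the output dtype trades memory for precision.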
## Dependencies
- Updated transformers: Now requires `>=4.55.0`
- Added liger-kernel: For optimized operations
- Added kernels package: For flash-attention-3 support
## Testing & Quality
- Comprehensive GPT-OSS Testing: New regression tests for MXFP4 conversion accuracy
- Enhanced OSFT Testing: Improved dtype functionality tests
## Statistics
- 1,841 additions, 173 deletions across 14 files
- Major new module: `gpt_oss_utils.py` (430+ lines)
- Significant enhancements to core training and model setup utilities
## Focus
This release primarily targets the GPT-OSS 20B model, enabling fine-tuning of OpenAI's open-weight MoE models. The memory optimizations apply to any large-model training; the 120B variant may also work but has not been extensively tested.
## GPT-OSS-20B Usage Example
For GPT-OSS 20B training, use the memory-efficient initialization and bfloat16 training dtype:
```python
from mini_trainer.api_train import run_training
from mini_trainer.training_types import TrainingArgs, TorchrunArgs

train_args = TrainingArgs(
    model_name="openai/gpt-oss-20b",
    osft_memory_efficient_init=True,
    train_dtype="bfloat16",
    ...  # other training arguments
)
torch_args = TorchrunArgs(...)  # distributed launch settings
run_training(torch_args, train_args)
```

## Full Changelog
Full Changelog: v0.1.1...v0.2.0a1