Releases: Red-Hat-AI-Innovation-Team/mini_trainer
v0.5.1
Fix: OSFT initialization failure on PyTorch < 2.9
Fixes a bug where OSFT failed to initialize on PyTorch versions below 2.9: the `use_batch` parameter, which was only added in PyTorch 2.9, was being passed to `torch.distributed.send_object_list` and `recv_object_list`.
Changes
- Add compatibility wrappers that detect PyTorch version capability
- Use signature probe with version parsing fallback for accurate detection
- Maintain backward compatibility with older PyTorch versions
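A minimal sketch of the compatibility-wrapper idea described above, assuming the wrapper only needs to decide whether to forward `use_batch`; the helper names (`_supports_use_batch`, `send_object_list_compat`) are illustrative, not mini_trainer's actual API:

```python
import inspect

import torch
import torch.distributed as dist


def _supports_use_batch(fn) -> bool:
    """Prefer a signature probe; fall back to parsing torch.__version__."""
    try:
        return "use_batch" in inspect.signature(fn).parameters
    except (TypeError, ValueError):
        # Signature unavailable (e.g. C extension): parse the version instead.
        major, minor = (int(p) for p in torch.__version__.split(".")[:2])
        return (major, minor) >= (2, 9)


def send_object_list_compat(object_list, dst, **kwargs):
    """Forward to dist.send_object_list, dropping use_batch on older PyTorch."""
    if not _supports_use_batch(dist.send_object_list):
        kwargs.pop("use_batch", None)
    return dist.send_object_list(object_list, dst, **kwargs)
```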
Full Changelog: v0.5.0...v0.5.1
v0.5.0: Pretraining Mode & OSFT Speed Boost ⚡
Summary: This release introduces pretraining-style training with block-based datasets and delivers a significant OSFT performance boost. Users can now train in pretraining mode by chunking dataset samples into dense, fixed-size token blocks, while OSFT training sees a 2x speedup through an optimized factorized linear forward pass.
Highlights
- 🎓 Pretraining Mode: New `--block-size` CLI flag enables pretraining-style training that packs samples into fixed-size token blocks
- ⚡ 2x Faster OSFT Forward Pass: Streamlined factorized linear computation reduces OSFT overhead from 4-5x to 2-4x relative to SFT
- 📦 PretrainingBlockDataset: Dense contiguous packing of input IDs with automatic chunking and padded minibatch collation
- 🔧 AdamW Parameters Exposed: Direct control over optimizer betas, epsilon, and weight decay
New Features
Pretraining Mode
Enable training with pre-training style datasets by @RobotSail in #62
- New `PretrainingConfig` dataclass with `block_size` parameter for chunking samples
- `PretrainingBlockDataset` concatenates all `input_ids` into a dense token stream and slices it into fixed-size blocks
- Automatic handling of partial final blocks with proper padding and loss masking
- Pretraining-aware data loaders with correct distributed sampling
- Deterministic startup seeding ensures reproducible training runs
- Tokenizer↔model pad-token alignment prevents silent training issues
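As a rough illustration of the packing scheme those bullets describe (a simplified sketch, not `PretrainingBlockDataset` itself; the function name and the `-100` label-masking convention are assumptions):

```python
import torch


def pack_into_blocks(samples, block_size, pad_token_id):
    """Concatenate every sample's input_ids into one stream and slice it into blocks."""
    stream = torch.cat([torch.as_tensor(s["input_ids"]) for s in samples])
    blocks = []
    for start in range(0, len(stream), block_size):
        chunk = stream[start:start + block_size]
        labels = chunk.clone()
        if len(chunk) < block_size:  # partial final block: pad inputs, mask the loss
            pad = torch.full((block_size - len(chunk),), pad_token_id)
            chunk = torch.cat([chunk, pad])
            labels = torch.cat([labels, torch.full_like(pad, -100)])
        blocks.append({"input_ids": chunk, "labels": labels})
    return blocks
```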
Usage:
```bash
torchrun --nproc_per_node=4 -m mini_trainer.train \
    --model-name-or-path meta-llama/Llama-3.2-1B \
    --data-path /path/to/dataset \
    --block-size 4096 \
    --batch-size 4 \
    ...
```
API Usage:
```python
from mini_trainer import TrainingArgs, PretrainingConfig, run_training

train_args = TrainingArgs(
    model_name_or_path="meta-llama/Llama-3.2-1B",
    data_path="/path/to/dataset",
    pretraining_config=PretrainingConfig(block_size=4096),
    ...
)
run_training(torch_args, train_args)
```
Performance Improvements
OSFT Factorized Linear Speedup
2x faster OSFT training by @NikhilNayak-debug in #61
- Removed unnecessary 3D→2D→3D tensor reshaping in `_factorized_linear`
- Eliminated redundant dtype conversions for OSFT artifacts (`U_high`, `V_high`, `S_high`, etc.)
- GEMM computation now defaults to the `MixedPrecisionPolicy` dtype
- Results: OSFT training time reduced from ~4-5x SFT to ~2-4x SFT
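To make the change concrete, here is an illustrative factorized forward in the spirit of `_factorized_linear` (a sketch only: it covers just the frozen high-rank factors and assumes they are already stored in the mixed-precision compute dtype, so no per-call casts are needed):

```python
import torch


def factorized_forward(x: torch.Tensor,
                       U_high: torch.Tensor,   # (out_features, r)
                       S_high: torch.Tensor,   # (r,)
                       V_high: torch.Tensor,   # (in_features, r)
                       bias: torch.Tensor | None = None) -> torch.Tensor:
    """Compute x @ (U_high diag(S_high) V_high^T)^T without reshaping x.

    torch.matmul broadcasts over leading dimensions, so a 3D activation of
    shape (batch, seq, in_features) is used directly -- no 3D→2D→3D round trip.
    """
    h = torch.matmul(x, V_high) * S_high          # (batch, seq, r)
    y = torch.matmul(h, U_high.transpose(0, 1))   # (batch, seq, out_features)
    if bias is not None:
        y = y + bias
    return y
```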
Documentation
- Added Pretraining Mode section to README with examples and usage patterns
- Cleaner runtime logs and diagnostics messaging
Testing
- New `tests/test_pretraining_dataset.py` with comprehensive block dataset tests
- New `tests/test_data_loader_pretraining.py` for pretraining data loader validation
- Updated existing test suites for compatibility with new features
Bug Fixes
- Preserved RNG state for checkpoint resumption
- Refined loss scaling and logging accuracy
- Fixed formatting issues across touched files
Upgrade Notes
- No breaking API changes
- To use pretraining mode, provide `--block-size` via CLI or set `pretraining_config` in `TrainingArgs`
- Dataset format for pretraining expects only an `input_ids` column (no length matching required)
Contributors
Installation
Through Pip:
```bash
uv pip install rhai-innovation-mini-trainer && uv pip install rhai-innovation-mini-trainer[cuda] --no-build-isolation
```
Locally:
```bash
uv pip install . && uv pip install .[cuda] --no-build-isolation
```
Full Changelog: v0.4.0...v0.5.0
v0.4.0: Memory-Optimal Model Loading for SFT + OSFT 💾
Summary: This release overhauls distributed initialization into a unified prepare → wrap → finalize pipeline that applies to both standard SFT and OSFT. The new
flow keeps all non-workload ranks on meta tensors until the last moment, rehydrates parameters shard-by-shard, and drives up to an 80% reduction in peak CPU + GPU
memory use when materializing large checkpoints. Alongside the loader refactor we hardened GPT-OSS conversions, fixed OSFT corner cases, refreshed docs, and sped up
our CI/tooling so the new path is well covered.
Highlights
- ♻️ Unified three-phase distributed initialization: `prepare_model_for_fsdp2`, `wrap_fsdp2`, and `finalize_model_initialization` now orchestrate rank-aware loading for both SFT and OSFT, so only rank 0 ever touches the full checkpoint and all other ranks stay lightweight on `meta` tensors.
- 📉 Up to 80% lower memory footprint: Materializing each shard directly inside its FSDP2 wrapper eliminates the previous all-gather spikes, keeping both CPU staging buffers and GPU allocations flat even on 70B-class models.
- 🧠 Safer lazy-init internals: The new `fsdp2_lazy_init` helpers and `ModelInitializationContext` track initialization state, repair stray tensor attributes, and make `finalize_model_initialization` idempotent for both OSFT factors and vanilla weights.
- 📚 Contributor polish: Fresh distributed-initialization documentation, README callouts, and CI/test updates make it easier to reason about the pipeline, while memory-efficient initialization stays fully automatic and the `osft_memory_efficient_init` flag is deprecated.
Performance Improvements
Unified Three-Phase Distributed Initialization
- Phase 1 (`prepare_model_for_fsdp2`) lets rank 0 load the checkpoint on CPU, record the state dict/buffers inside `ModelInitializationContext`, and immediately free the heavyweight module, while every other rank instantiates the model on `meta`.
- Phase 2 (`wrap_fsdp2`) now focuses solely on activation-checkpointing each transformer block, building the device mesh, and calling `fully_shard`, so sharding incurs zero data movement.
- Phase 3 (`finalize_model_initialization`) streams the stored payload back into the sharded modules (broadcasting vanilla tensors for SFT and driving the distributed SVD path for OSFT) without ever inflating a full-model copy.
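For orientation, here is a rough, generic sketch of the prepare/finalize pattern described above, using plain PyTorch and Hugging Face APIs. It is illustrative only: mini_trainer's `prepare_model_for_fsdp2` and `finalize_model_initialization` add the `ModelInitializationContext` bookkeeping, buffer capture, OSFT factor handling, and idempotency guarantees that this sketch omits, and the exact materialization calls may differ.

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
from transformers import AutoConfig, AutoModelForCausalLM


def prepare(model_name: str):
    """Phase 1 (sketch): rank 0 keeps the real payload on CPU; every rank builds meta tensors."""
    full_state_dict = {}
    if dist.get_rank() == 0:
        loaded = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
        full_state_dict = loaded.state_dict()
        del loaded  # free the heavyweight module immediately; the state dict keeps the tensors
    with torch.device("meta"):  # structure only, no storage allocated
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name))
    return model, full_state_dict


def finalize(model, full_state_dict):
    """Phase 3 (sketch): materialize the shards, then stream rank-0's payload into them.

    Assumes Phase 2 (per-block fully_shard on the meta model) has already run.
    """
    model.to_empty(device="cuda")  # allocate sharded storage (contents still uninitialized)
    set_model_state_dict(
        model,
        full_state_dict,  # only rank 0 holds real tensors; other ranks pass an empty dict
        options=StateDictOptions(full_state_dict=True, broadcast_from_rank0=True),
    )
    return model
```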
Memory-Safe OSFT Internals
- Introduced `fsdp2_lazy_init` tagging so every model knows whether it is in SFT or OSFT lazy-init mode, preventing accidental re-entry and enabling clean teardown/reinitialization.
- `_sanitize_meta_attribute_aliases` now repairs stray tensors (buffers/parameters that lived on `meta` or CPU) before wrapping, stopping the silent OOM regressions caused by attribute clones that FSDP2 could not shard.
- OSFT models keep their dense payload exclusively on rank 0 until `set_model_state_dict` receives the distributed SVD output, so non-owned tensor references never leak into other ranks.
Activation Checkpointing + FSDP2 Integration
- Every transformer block is automatically wrapped with `checkpoint_wrapper` as part of Phase 2, dramatically shrinking activation memory and aligning the FSDP2 graph between SFT and OSFT.
- Device-mesh creation, mixed-precision policies, activation checkpointing, and `fully_shard` now execute in isolation, so the finalize phase can focus on populating weights and clearing buffers.
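A minimal sketch of the Phase 2 wrapping order those bullets describe (illustrative, not mini_trainer's code; older PyTorch releases expose `fully_shard` under `torch.distributed._composable.fsdp` instead):

```python
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from torch.distributed.fsdp import fully_shard


def wrap_blocks(blocks: nn.ModuleList) -> nn.ModuleList:
    """Activation-checkpoint each transformer block, then shard it with FSDP2."""
    wrapped = nn.ModuleList()
    for block in blocks:
        block = checkpoint_wrapper(block)  # recompute activations during backward
        fully_shard(block)                 # shard this block's parameters across the mesh
        wrapped.append(block)
    return wrapped
```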
GPT-OSS & Control-Plane Reliability
- GPT-OSS MXFP4 conversion now enforces consistent float32 tie-breaking, bias handling, and dtype casting so checkpoint loads no longer hit mismatched tensor errors.
- A dedicated CPU-friendly (Gloo) control process group carries object/broadcast traffic, eliminating NCCL stalls when dispatching OSFT contexts or checkpoint
metadata.
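The control-plane split can be pictured in a few lines (a sketch, not mini_trainer's code; the payload contents are made up):

```python
import torch.distributed as dist

# Keep heavy tensor traffic on the default NCCL group, while a separate Gloo
# group carries small object/broadcast messages so CPU-side metadata exchange
# never stalls NCCL.
dist.init_process_group(backend="nccl")          # data plane: GPU tensors
control_group = dist.new_group(backend="gloo")   # control plane: CPU objects

payload = [{"osft_context": "..."}] if dist.get_rank() == 0 else [None]
dist.broadcast_object_list(payload, src=0, group=control_group)
```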
Developer Experience & Tooling
- Added `docs/distributed_initialization.md` plus a README callout that diagrams the three phases for contributors and documents both SFT and OSFT entry points.
- The `osft_memory_efficient_init` flag is officially deprecated: distributed jobs always use the lazy-init path, while single-node/non-distributed runs default to eager loading with helpful warnings.
- API wrappers and CLI plumbing emit clear warnings when legacy flags are present and provide actionable guidance for non-distributed scenarios.
Testing & CI
- Revamped `tests/test_model_initialization.py`, `tests/test_integration_small_models.py`, and related harnesses to cover the new lazy-init state machine, GPT-2/Qwen variants, and both SFT/OSFT permutations.
- Added targeted `tox` environments (`lint-changed`, `format-check-changed`) and taught the GitHub Actions workflows to fetch full history so lint/format checks only touch files that changed, cutting CI runtime.
- Orthogonality utilities now surface debug info immediately so regression tests fail loudly when distributed SVD inputs misbehave.
Bug Fixes & Reliability
- Non-distributed OSFT execution once again bypasses the lazy-init path and honors `fsdp2_lazy_init=False`, preventing hangs when torch.distributed is absent.
- Fixed dtype mismatches and bias handling while loading GPT-OSS checkpoints, avoiding corrupted experts after MXFP4 quantization.
- Progress bars now advance the epoch counter correctly after each loop, keeping CLI/UI metrics accurate.
- Tightened error handling around distributed saves, synchronized barriers, and cleanup so we do not leak memory when exceptions propagate.
Upgrade Notes
- Remove `osft_memory_efficient_init` from configs: distributed memory-efficient initialization is automatic and the flag will be removed in v0.5.0.
- Use `torchrun`/distributed launches whenever you want the lazy-init benefits; non-distributed jobs should explicitly keep `fsdp2_lazy_init=False`.
- Historical `CHANGELOG.md` entries now live in GitHub Releases (this document), so reference release tags for future diffs.
Contributors
Full Changelog: v0.3.1...v0.4.0
v0.4.0a1: Memory-Optimal Model Loading for SFT + OSFT 💾
Pre-release of v0.4.0. The release notes are identical to the v0.4.0 entry above.
v0.3.1: Minor bug fixes and improvements
What's Changed
- Fixes progress bar epoch counter by @RobotSail in #54
Full Changelog: v0.3.0...v0.3.1
v0.3.0: OSFT Unleashed - 3x Memory Savings & Beautiful Progress 🚀
Summary: This release delivers critical performance improvements for OSFT training with a 3x memory reduction and orthogonalization bug fixes, enhanced user experience through rich progress bars and colorful logging, comprehensive testing infrastructure for OSFT, and expanded model architecture support. The release focuses on making OSFT production-ready while improving the overall developer experience.
Highlights
- 🧠 Major OSFT Memory Optimization: 3x memory reduction through FSDP2 sharding (baseline now comparable to SFT)
- 🔧 Critical OSFT Orthogonalization Fix: Corrected distributed gradient projection for mathematical correctness
- 📊 Rich Progress Bars: Beautiful, informative training progress with real-time metrics
- 🧪 Comprehensive OSFT Testing: New regression test suite to validate orthogonality constraints
- 🦎 Enhanced Mamba Support: Added specialized convolution kernels for NVIDIA/AMD GPUs
- 📚 Modernized Documentation: Refreshed README with improved styling and clarity
Performance Improvements
OSFT Memory Optimization & Bug Fixes
Memory usage reduced by ~3x - Critical improvements to OSFT's memory footprint and correctness by @NikhilNayak-debug in #47
Memory Optimizations:
- Registered U_high/S_high/V_high as non-trainable parameters (not buffers) so FSDP2 shards them across GPUs instead of replicating
- Moved OSFT tensors under their owning Linear modules to avoid full-model all-gather
- Per-block param materialization prevents whole-model memory spikes
- Results: OSFT baseline memory reduced from ~52 GB to ~15 GB, peak from ~52 GB to ~24.6 GB (comparable to SFT)
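The core of the parameters-not-buffers change can be shown in a few lines (a simplified sketch; the class name is illustrative, and mini_trainer's `SVDModule` from this release handles the real integration). FSDP2's `fully_shard` shards `nn.Parameter`s, including frozen ones, whereas buffers stay replicated on every rank:

```python
import torch
import torch.nn as nn


class FrozenSVDFactors(nn.Module):
    """Register the frozen SVD factors as non-trainable parameters, not buffers,
    so fully_shard distributes them across GPUs instead of replicating them."""

    def __init__(self, U_high: torch.Tensor, S_high: torch.Tensor, V_high: torch.Tensor):
        super().__init__()
        self.U_high = nn.Parameter(U_high, requires_grad=False)
        self.S_high = nn.Parameter(S_high, requires_grad=False)
        self.V_high = nn.Parameter(V_high, requires_grad=False)
```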
Orthogonalization Fixes:
- Fixed distributed gradient projection to be mathematically correct across shards
- U projection (row-sharded): proper global contraction with all-reduce SUM
- V projection (row-sharded): corrected Gram matrix computation with global reduction
- Gradient projection now operates on local shards with minimal all-reduce for global correctness
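As an illustration of the corrected U-side projection (a sketch under the assumption that the gradient and `U_high` are row-sharded the same way; the name and signature are not mini_trainer's internals):

```python
import torch
import torch.distributed as dist


def project_out_high_subspace(grad_shard: torch.Tensor, U_high_shard: torch.Tensor) -> torch.Tensor:
    """Remove the component of a row-sharded gradient that lies in span(U_high).

    The contraction U_high^T @ grad runs over *all* rows, so each rank computes
    its local partial product and all-reduces (SUM) before subtracting.
    """
    partial = U_high_shard.T @ grad_shard            # local contribution to U^T g
    dist.all_reduce(partial, op=dist.ReduceOp.SUM)   # global contraction across shards
    return grad_shard - U_high_shard @ partial       # g - U (U^T g), computed shard-locally
```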
New Features
Enhanced User Experience
- Rich Progress Bars & Colorful Logging by @RobotSail in #48
- Beautiful progress bars during training and evaluation using rich console
- Real-time metrics display: epoch/step, loss, learning rate, and throughput
- Colored output with timestamps and JSON rendering for better readability
- Lazy-initialized progress lines for efficient rendering
- OSFT Orthogonalization Test Suite by @RobotSail in #50
- Comprehensive regression test validating OSFT orthogonality constraints during training
- Monitors both gradient orthogonality (before optimizer steps) and parameter orthogonality (after optimizer steps)
- Detailed per-step reporting with aggregated violation summaries
- Added SVDModule class for better encapsulation of SVD components
- Supports distributed training validation across multiple GPUs
- Usage: `torchrun --nproc_per_node=2 regression_tests/test_osft_orthogonalization.py --model Qwen/Qwen2.5-1.5B-Instruct --num-steps 100`
Model Architecture Support
- Mamba Convolution Kernels by @RobotSail in #46
- Added mamba-ssm[causal-conv1d] dependency for specialized NVIDIA/AMD GPU kernels
- Enables efficient Mamba architecture support with hardware-optimized operations
- Enhanced GPT-2 Family Support in #50
- Added OSFT support for GPT-2 model family
- Broadened transformer-block discovery to support more Hugging Face architectures
Torchrun Improvements
- Flexible Torchrun Arguments by @szaher in #44
- `nproc_per_node` now accepts both string ("gpu") and integer values
- `rdzv_id` now accepts both string and integer types
- More flexible rendezvous options: choose either master address/port (static) or rendezvous endpoint
- Launch command uses hyphenated flags and conditionally builds static vs endpoint-based rendezvous
Documentation
- Modernized README by @RobotSail in #49
- Added modern badges (CI status, Python version, license)
- Integrated emojis for better navigation and visual appeal
- Updated installation commands with proper PyPI and source instructions
- Streamlined usage documentation focusing on core functionality
- Added bug reporting section with clear guidance
- Removed outdated content for cleaner, more maintainable docs
Dependencies & Infrastructure
- Improved Dependency Management by @RobotSail in #51
- Removed numpy version ceiling (aligned with numba's policy)
- Moved tox and tox-uv to optional [dev] dependencies
- Applied minimum version requirement for numba
- Streamlined end-user installations
Bug Fixes
- Fixed validation sampler epoch handling by @RobotSail in #45
- Removed incorrect epoch setting on SequentialSampler
- Improved validation data handling consistency across epochs
- Prevents unintended resets of validation sampler state
Test Infrastructure
- Expanded test environments with explicit install flow per GPU/non-GPU
- Improved tox environments with conditional CUDA/flash-attn setup
- Enhanced CI with dedicated virtual environment for consistent tooling
- Added comprehensive OSFT orthogonality regression tests
Example Usage
Training with Progress Bars
```python
from mini_trainer.api_train import run_training
from mini_trainer.training_types import TrainingArgs, TorchrunArgs

train_args = TrainingArgs(
    model_name="Qwen/Qwen2.5-1.5B-Instruct",
    osft=True,
    osft_rank_ratio=0.5,
    ...
)
# Beautiful progress bars will automatically display during training
run_training(torch_args, train_args)
```
Running OSFT Orthogonalization Tests
```bash
torchrun --nproc_per_node=2 regression_tests/test_osft_orthogonalization.py \
    --model Qwen/Qwen2.5-1.5B-Instruct \
    --num-steps 100 \
    --margin-deg 1.0 \
    --rank-ratio 0.5
```
Upgrade Notes
- No breaking API changes
- OSFT users will see significant memory improvements automatically
- Progress bars work best when running `train.py` directly through torchrun (`api_train.py` streams byte-for-byte and may reprint progress bars)
- New orthogonalization test suite available for validating OSFT training correctness
Contributors
Installation
Through Pip:
```bash
uv pip install rhai-innovation-mini-trainer && uv pip install rhai-innovation-mini-trainer[cuda] --no-build-isolation
```
Locally:
```bash
uv pip install . && uv pip install .[cuda] --no-build-isolation
```
Full Changelog: v0.2.1...v0.3.0
v0.2.1: Enhanced GPT-OSS Checkpointing
Mini Trainer v0.2.1 Release Notes
🚀 What's Changed
Performance Improvements
- Optimize GPT-OSS model saving with GPU chunking - Significantly faster model checkpointing by moving expert parameter quantization to GPU with smart memory management to avoid OOM errors. by @NikhilNayak-debug in #43
Dependencies
- Bump NumPy version compatibility to `<2.3` for better modern environment support by @RobotSail in #41
- Move Liger kernel to CUDA dependencies: ARM64 systems cannot easily install Liger due to Triton dependencies, so it moves to the optional CUDA extra for consistency with other libraries like instructlab-training. By @RobotSail in #42
Installation
Through Pip:
```bash
uv pip install rhai-innovation-mini-trainer && uv pip install rhai-innovation-mini-trainer[cuda] --no-build-isolation
```
Locally:
```bash
uv pip install . && uv pip install .[cuda] --no-build-isolation
```
Full Changelog: https://github.com/Red-Hat-AI-Innovation-Team/mini_trainer/compare/v0.2.0...v0.2.1
v0.2.0: The GPT-OSS Release
Summary: This release introduces support for OpenAI's GPT-OSS models (20B/120B) with native MXFP4 quantization and memory-efficient training, adds Weights & Biases integration for experiment tracking, implements train/validation splits with loss reporting, and includes new data conversion tools. The release focuses on enabling efficient fine-tuning of large language models while maintaining ease of use.
Highlights
- 🚀 GPT-OSS Model Support: Full support for OpenAI's open-weight models (20B and 120B variants)
- 🧠 Memory-Efficient Training: OSFT memory optimization and flexible dtype controls
- 🖌️ Weights & Biases Integration: Effortless experiment tracking with wandb
- 📊 Train/Validation Split: Built-in validation support with loss reporting
- 🛠️ Enhanced Tooling: Pretraining conversion scripts and improved data processing
New Features
Model Support & Memory Optimization
- GPT-OSS Model Support
- Full support for OpenAI's GPT-OSS models (20B and 120B variants)
- Native MXFP4 quantization implementation
- New `gpt_oss_utils.py` module (430+ lines)
- Memory-Efficient OSFT Initialization
- New `osft_memory_efficient_init` flag for optimized initialization
- Significant memory savings during model loading
- Training Dtype Control
- New `train_dtype` parameter for bf16/fp16 training
- `osft_upcast_dtype` for computation precision (default: float32)
- `osft_output_dtype` for output precision control
Training Infrastructure
- Weights & Biases Support
- New `wandb_wrapper.py` module
- Automatic logging of training/validation metrics, gradients, and system stats
- Opt-in via `--wandb` CLI flag
- Train/Validation Split
- Deterministic split into train and validation shards
- New `--validation-split` argument (default 0.05)
- Validation loop runs every `validation_frequency` steps
- Validation loss computation and reporting
Tools & Scripts
- Data Processing Enhancements
- New `convert_to_pretrain.py` script for dataset conversion
- Improved `process_data.py` with additional functionality
Dependencies & Infrastructure
- Updated transformers to `>=4.55.0`
- Added liger-kernel for optimized operations
- Added kernels package for flash-attention-3 support
- Enhanced test coverage for validation and sampler behavior
Bug Fixes
- GPT-OSS checkpoint saving during SFT
- Distributed torch tests stability
- Dtype conversion edge-cases
- Default `validation_frequency` is now `None` instead of `0`
- Fixed various test case failures and applied code optimizations
Example Usage
```python
from mini_trainer.api_train import run_training
from mini_trainer.training_types import TrainingArgs, TorchrunArgs

train_args = TrainingArgs(
    model_name="openai/gpt-oss-20b",
    osft_memory_efficient_init=True,
    train_dtype="bfloat16",
    wandb=True,                 # Enable wandb logging
    validation_split=0.05,      # 5% validation split
    validation_frequency=100,   # Validate every 100 steps
    ...
)
run_training(torch_args, train_args)
```
Upgrade Notes
No breaking API changes. Primary focus on GPT-OSS 20B model support (120B variant potentially supported but not extensively tested). WandB logging requires wandb>=0.16.
Contributors
Full Changelog: v0.1.1...v0.2.0
v0.2.0a1: GPT-OSS Support & Memory Optimizations 🧠 [Alpha 1]
👤 Authors
🚀 New Features
- GPT-OSS Model Support: Full support for OpenAI's new open-weight GPT-OSS models (20B and 120B variants) with native MXFP4 quantization implementation
- Memory-Efficient OSFT Initialization: New `osft_memory_efficient_init` flag for optimized initialization of large models
- Training Dtype Control: New `train_dtype` parameter for switching models to bf16/fp16 training to reduce memory usage (use sparingly, as lower precision may impact results)
- Pretraining Data Conversion: New `convert_to_pretrain.py` script for converting conversation datasets to pretraining format
🛠️ API Enhancements
- OSFT Dtype Controls: `osft_upcast_dtype` for computation precision (default: float32), `osft_output_dtype` for output precision control
- Enhanced Data Processing: Improved `process_data.py` with additional functionality
📦 Dependencies
- Updated transformers: Now requires `>=4.55.0`
- Added liger-kernel: For optimized operations
- Added kernels package: For flash-attention-3 support
🧪 Testing & Quality
- Comprehensive GPT-OSS Testing: New regression tests for MXFP4 conversion accuracy
- Enhanced OSFT Testing: Improved dtype functionality tests
📊 Statistics
- 1,841 additions, 173 deletions across 14 files
- Major new module: `gpt_oss_utils.py` (430+ lines)
- Significant enhancements to core training and model setup utilities
🎯 Focus
This release primarily targets the GPT-OSS 20B model, enabling fine-tuning of OpenAI's open-weight MoE models. The memory optimizations are applicable to any large model training, with the 120B variant potentially supported but not extensively tested.
💡 GPT-OSS-20B Usage Example
For GPT-OSS 20B training, use the memory-efficient initialization and bfloat16 training dtype:
```python
from mini_trainer.api_train import run_training
from mini_trainer.training_types import TrainingArgs, TorchrunArgs

train_args = TrainingArgs(
    model_name="openai/gpt-oss-20b",
    osft_memory_efficient_init=True,
    train_dtype="bfloat16",
    ...  # other training arguments
)
run_training(torch_args, train_args)
```
📜 Full Changelog
Full Changelog: v0.1.1...v0.2.0a1
v0.1.1: Improvements & Stability Fixes
What's Changed
- add port to default rdzv_endpoint by @RobotSail in #29
- add cpu offloading during checkpointing by @RobotSail in #31
- Enables OSFT to run in multi-node distributed workloads by @RobotSail in #33
Full Changelog: v0.1.0...v0.1.1