An optimized implementation of TinyRecursiveModels using the MLX framework for Apple Silicon.
This project implements a recursive transformer architecture that iteratively refines a latent reasoning state over multiple improvement cycles, achieving strong performance with a significantly smaller parameter count.
- MLX Optimized: Designed specifically for Apple Silicon's unified memory architecture.
- Recursive Reasoning: Implements the core $n$ latent recursions and $T$ improvement cycles from the TRM paper.
- Deep Supervision: Training with supervision at each improvement step for better stability.
- EMA (Exponential Moving Average): Integrated weight averaging for better generalization.
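As a rough sketch, the recursion scheme above can be illustrated in plain Python. The function names (`f_latent`, `f_answer`) and the toy update rules are illustrative assumptions, not the repo's actual API:

```python
def trm_forward(x, f_latent, f_answer, n=6, T=3, z0=0.0, y0=0.0):
    """Illustrative TRM-style forward pass: T improvement cycles,
    each running n latent recursions before updating the answer y."""
    z, y = z0, y0
    for _ in range(T):            # T improvement cycles
        for _ in range(n):        # n latent recursions refine the latent state z
            z = f_latent(x, y, z)
        y = f_answer(y, z)        # the answer is improved once per cycle
    return y

# Toy example: the latent step halves the gap between z and x, and the
# answer step copies z, so y converges toward x as cycles accumulate.
result = trm_forward(
    x=1.0,
    f_latent=lambda x, y, z: 0.5 * (x + z),
    f_answer=lambda y, z: z,
    n=4, T=2,
)
```

The key design point is that only the small latent state `z` is updated inside the inner loop, which is what lets a tiny network reuse its weights across many reasoning steps.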
- v1.0: https://huggingface.co/Kamisori-daijin/textrm-28M-bizmail
- v1.5: https://huggingface.co/Kamisori-daijin/textrm1.5-25M-bizmail
- Set up the environment:
  ```bash
  python -m venv .venv
  source .venv/bin/activate
  ```
- Install requirements:
  ```bash
  pip install -r requirements.txt
  ```
- Configure the model: adjust hyperparameters in `models/config.py`.
- Train the model. The training script uses MLX's efficient gradient computation and automatic hardware acceleration.
  ```bash
  python train.py
  ```
  The best weights are saved as `best_model_mlx.safetensors`.
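The deep supervision mentioned above can be sketched as follows. This is a minimal illustration in plain Python, not the repo's training loop; the function names and toy update rules are assumptions:

```python
def deep_supervised_loss(x, target, f_latent, f_answer, loss_fn, n=4, T=3):
    """Illustrative deep supervision: the loss is accumulated from the
    answer produced after *every* improvement cycle, not only the last."""
    z, y, total = 0.0, 0.0, 0.0
    for _ in range(T):
        for _ in range(n):
            z = f_latent(x, y, z)
        y = f_answer(y, z)
        total += loss_fn(y, target)   # supervise each intermediate answer
    return total / T                  # average per-cycle loss

# Toy run with illustrative update rules and a squared-error loss.
avg_loss = deep_supervised_loss(
    x=1.0, target=1.0,
    f_latent=lambda x, y, z: 0.5 * (x + z),
    f_answer=lambda y, z: z,
    loss_fn=lambda y, t: (y - t) ** 2,
    n=4, T=2,
)
```

Supervising every cycle gives each intermediate answer a gradient signal, which is what stabilizes training compared with backpropagating only through the final answer.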
- Run inference. Generate text with the trained model:
  ```bash
  python inference.py --prompt "Write a polite refusal email"
  ```
- `train.py`: Main entry point for training.
- `inference.py`: Interactive text generation.
- `models/`: MLX model definitions (`trm_model.py`, `trm_build.py`).
- `training/`: MLX-specific training loop and logic.
- `ema/`: Exponential Moving Average implementation for MLX.
- `dataset/`: Dataset loading and tokenization logic.
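The weight averaging provided by the `ema/` module follows the standard EMA update. The sketch below uses plain Python floats for clarity; the repo's implementation presumably operates on MLX parameter trees instead, so treat the class and its interface as illustrative assumptions:

```python
class EMA:
    """Minimal exponential moving average over a dict of scalar parameters
    (illustrative only; not the repo's MLX-based implementation)."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = dict(params)   # shadow weights start at current values

    def update(self, params):
        d = self.decay
        for k, v in params.items():
            # shadow <- decay * shadow + (1 - decay) * current
            self.shadow[k] = d * self.shadow[k] + (1 - d) * v

ema = EMA({"w": 0.0}, decay=0.9)
ema.update({"w": 1.0})   # shadow moves 10% of the way toward the new value
```

After training, evaluating with the shadow weights instead of the raw weights smooths out noise from the final optimizer steps, which is the generalization benefit the feature list refers to.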