
textrm: Text Generation Model with TRM

An optimized implementation of Tiny Recursive Models (TRM) built on the MLX framework for Apple Silicon.

This project implements a recursive transformer architecture that iteratively refines a latent reasoning state over multiple improvement cycles, achieving strong performance with a significantly smaller parameter count.

Features

  • MLX Optimized: Designed specifically for Apple Silicon's unified memory architecture.
  • Recursive Reasoning: Implements the core $n$ latent recursions and $T$ improvement cycles from the TRM paper.
  • Deep Supervision: Training with supervision at each improvement step for better stability.
  • EMA (Exponential Moving Average): Integrated weight averaging for better generalization.
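The EMA weight averaging can be sketched as follows. This is a minimal sketch assuming parameters stored in a plain dict of floats; the MLX version in ema/ tracks trees of `mx.array` weights instead.

```python
class EMA:
    """Exponential moving average of model parameters (illustrative sketch)."""

    def __init__(self, params, decay=0.999):
        self.decay = decay
        # Shadow copy that accumulates the averaged weights.
        self.shadow = {name: value for name, value in params.items()}

    def update(self, params):
        # shadow <- decay * shadow + (1 - decay) * current
        for name, value in params.items():
            self.shadow[name] = (
                self.decay * self.shadow[name] + (1.0 - self.decay) * value
            )

# Usage: call update() after every optimizer step; evaluate with ema.shadow.
params = {"w": 1.0}
ema = EMA(params, decay=0.9)
params["w"] = 0.0
ema.update(params)
```

Evaluating with the shadow weights rather than the raw trained weights is what gives the generalization benefit mentioned above.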

Pre-Trained Model

Usage

  1. Set up the environment

    python -m venv .venv
    source .venv/bin/activate
  2. Install requirements

    pip install -r requirements.txt 
  3. Configure the model

     Adjust hyperparameters in models/config.py.

  4. Train the model

     The training script uses MLX's efficient gradient computation and automatic hardware acceleration.

    python train.py 

    Best weights will be saved as best_model_mlx.safetensors.

  5. Run inference

     Generate text using the trained model:

    python inference.py --prompt "Write a polite refusal email"
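The deep supervision used during training means a loss is attached to the answer produced after every improvement cycle, not just the final one. A toy sketch with a squared-error loss and hypothetical names (the real train.py computes gradients via MLX rather than this scalar example):

```python
def deep_supervision_loss(cycle_outputs, target):
    """Sum the loss over the answer produced at every improvement cycle.

    cycle_outputs: one (toy scalar) answer per improvement cycle.
    """
    return sum((y - target) ** 2 for y in cycle_outputs)

# Toy example: three cycles, each output closer to the target of 1.0.
outputs = [0.2, 0.6, 0.9]
loss = deep_supervision_loss(outputs, target=1.0)
```

Because every cycle contributes to the loss, early cycles receive a direct training signal, which is what stabilizes the recursive training.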

Project Structure

  • train.py: Main entry point for training.
  • inference.py: Interactive text generation.
  • models/: MLX model definitions (trm_model.py, trm_build.py).
  • training/: MLX-specific training loop and logic.
  • ema/: Exponential Moving Average implementation for MLX.
  • dataset/: Dataset loading and tokenization logic.

Acknowledgments
