This project implements an autoregressive transformer model for generating MNIST-like handwritten digits. The model predicts and generates images one patch at a time, and is trained on the MNIST dataset to produce new handwritten digits that follow similar patterns.
- model.py: Contains the transformer model architecture
- data.py: Handles data loading and preprocessing
- train.py: Main training script
- utils.py: Utility functions for visualization and processing
- config.py: Configuration parameters for the model and training
- build_codebook.py: Script for building the image token codebook
- generated_images/: Directory containing generated images
- data/: Directory containing training data
- Model checkpoints:
  - best_pixel_transformer.pth: Best performing model checkpoint
  - final_pixel_transformer.pth: Final model checkpoint
  - pixel_transformer.pth: Latest model checkpoint
- codebook.pkl: Pre-computed codebook for image tokenization
- Image tokenization using patch-based approach with K-means clustering for codebook generation
- Autoregressive transformer architecture
- Training on MNIST handwritten digits dataset
- Generation of new MNIST-like handwritten digits
- Configurable model parameters
- Progress tracking and visualization
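As a sketch of the patch-based tokenization feature, the snippet below builds a K-means codebook from image patches and maps patches to token ids. The function names (`build_codebook`, `tokenize`) and the patch size and vocabulary size are illustrative assumptions; the actual `build_codebook.py` may differ.

```python
import numpy as np
from sklearn.cluster import KMeans

def extract_patches(images, patch_size=4):
    """Split (N, 28, 28) images into flattened non-overlapping patches."""
    n, h, w = images.shape
    p = patch_size
    patches = images.reshape(n, h // p, p, w // p, p)
    # Reorder so each (p, p) patch is contiguous, then flatten patches.
    return patches.transpose(0, 1, 3, 2, 4).reshape(-1, p * p)

def build_codebook(images, num_tokens=256, patch_size=4, seed=0):
    """Cluster all patches with K-means; the centroids form the codebook."""
    patches = extract_patches(images, patch_size)
    km = KMeans(n_clusters=num_tokens, n_init=4, random_state=seed).fit(patches)
    return km.cluster_centers_  # shape: (num_tokens, patch_size ** 2)

def tokenize(image, codebook, patch_size=4):
    """Map each patch of one image to its nearest codebook entry's index."""
    patches = extract_patches(image[None], patch_size)
    dists = ((patches[:, None, :] - codebook[None, :, :]) ** 2).sum(-1)
    return dists.argmin(axis=1)  # 49 token ids for a 28x28 image, 4x4 patches

if __name__ == "__main__":
    # Demonstrate on random data standing in for MNIST images.
    rng = np.random.default_rng(0)
    fake_images = rng.random((32, 28, 28)).astype(np.float32)
    cb = build_codebook(fake_images, num_tokens=16)
    tokens = tokenize(fake_images[0], cb)
    print(cb.shape, tokens.shape)  # (16, 16) (49,)
```

Decoding reverses the mapping: each token id is replaced by its centroid patch and the patches are reassembled into a 28x28 image.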
- Create and activate a virtual environment (recommended):

  ```bash
  python -m venv venv
  source venv/bin/activate
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

- Build the codebook (if not using the pre-computed one):

  ```bash
  python build_codebook.py
  ```

- Run training:

  ```bash
  python train.py
  ```

The model works by:
- Breaking down MNIST images into patches
- Converting patches into tokens using a K-means clustering based codebook
- Using a transformer to predict the next patch in the sequence
- Generating new MNIST-like digits autoregressively
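The next-patch prediction and autoregressive sampling steps above can be sketched as a minimal decoder-only transformer over patch-token sequences. This is an assumption-laden illustration: the class name, layer counts, embedding sizes, and the use of token id 0 as a start token are all hypothetical and need not match the architecture in `model.py`.

```python
import torch
import torch.nn as nn

class PatchTransformer(nn.Module):
    """Minimal causal transformer over patch tokens (illustrative sketch)."""
    def __init__(self, vocab_size=256, seq_len=49, d_model=64, n_head=4, n_layer=2):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(seq_len, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_head, 4 * d_model,
                                           batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, n_layer)
        self.head = nn.Linear(d_model, vocab_size)

    def forward(self, tokens):
        # tokens: (B, T) patch-token ids -> (B, T, vocab_size) next-patch logits
        b, t = tokens.shape
        pos = torch.arange(t, device=tokens.device)
        x = self.tok_emb(tokens) + self.pos_emb(pos)
        # Causal mask so each position attends only to earlier patches.
        mask = nn.Transformer.generate_square_subsequent_mask(t).to(tokens.device)
        return self.head(self.blocks(x, mask=mask))

@torch.no_grad()
def generate(model, num_tokens=49, temperature=1.0):
    """Sample a full patch sequence one token at a time."""
    tokens = torch.zeros(1, 1, dtype=torch.long)  # assumed start token id 0
    for _ in range(num_tokens - 1):
        logits = model(tokens)[:, -1] / temperature
        nxt = torch.multinomial(logits.softmax(-1), 1)
        tokens = torch.cat([tokens, nxt], dim=1)
    return tokens.squeeze(0)  # token ids; decode via the codebook to an image

if __name__ == "__main__":
    model = PatchTransformer(vocab_size=16).eval()
    seq = generate(model)
    print(seq.shape)  # torch.Size([49])
```

Decoding the sampled token ids back through the codebook (one centroid patch per id) yields the generated 28x28 digit.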
- Python 3.8+
- PyTorch 2.0.0+
- torchvision 0.15.0+
- numpy 1.21.0+
- matplotlib 3.4.0+
- Pillow 8.0.0+
- tqdm 4.65.0+
- einops 0.6.0+
See requirements.txt for full list of dependencies.