A PyTorch implementation of I-JEPA (Image Joint-Embedding Predictive Architecture) using pretrained Vision Transformers (ViT-B/32) as encoders. This implementation follows the principles from the original I-JEPA paper, focusing on learning semantic representations by predicting masked image regions in the embedding space rather than pixel space.
I-JEPA is a self-supervised learning method that:
- Predicts embeddings of masked image regions rather than reconstructing pixels
- Learns semantic representations that capture high-level image understanding
- Uses asymmetric architecture with context and target encoders
- Avoids representation collapse through stop-gradient and EMA updates
- ✅ Pretrained ViT-B/32 encoders from the `timm` library
- ✅ Multiple masking strategies (random, block, multi-block)
- ✅ Lightweight transformer predictor network
- ✅ EMA (Exponential Moving Average) for target encoder updates
- ✅ Comprehensive visualization tools
- ✅ Support for both CIFAR-10 and ImageNet datasets
- ✅ Wandb integration for experiment tracking
```bash
# Clone the repository
cd B-Jepa

# Install dependencies
pip install -r requirements.txt
```

The implementation automatically detects and uses the best available device:
- CUDA - NVIDIA GPUs (highest priority)
- MPS - Apple Silicon M1/M2/M3 GPUs (Metal Performance Shaders)
- CPU - Fallback for any system
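The priority order above can be sketched as a small helper (illustrative only; the actual logic lives in `check_device.py` and the training script):

```python
import torch

def pick_device() -> torch.device:
    """Select the best available device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```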
You can also manually specify a device:

```bash
python train_ijepa.py --device mps   # Force MPS
python train_ijepa.py --device cuda  # Force CUDA
python train_ijepa.py --device cpu   # Force CPU
```

The easiest way to get started is with the example script:
```bash
python example_usage.py
```

This will:
- Test all model components
- Visualize different masking strategies
- Run a quick training demo on CIFAR-10
For a quick training run on CIFAR-10:

```bash
python train_ijepa.py \
    --batch_size 128 \
    --num_epochs 50 \
    --lr 1e-4 \
    --use_pretrained \
    --data_dir ./data
```

The script will automatically detect and use the best available device (CUDA > MPS > CPU).
For ImageNet training:

```bash
python train_ijepa.py \
    --batch_size 256 \
    --num_epochs 100 \
    --lr 1e-4 \
    --use_pretrained \
    --use_imagenet \
    --data_dir /path/to/imagenet \
    --wandb \
    --wandb_project ijepa_imagenet
```

The repository is organized as follows:

```
B-Jepa/
├── ijepa_model.py       # Core I-JEPA model implementation
├── masking.py           # Masking strategies and collators
├── train_ijepa.py       # Training script with trainer class
├── inference.py         # Comprehensive inference and analysis
├── visualize_ijepa.py   # Visualization utilities
├── example_usage.py     # Demo and testing script
├── check_device.py      # Device detection utility
├── requirements.txt     # Dependencies
└── README.md            # This file
```
- **Context Encoder (ViT-B/32)**
  - Processes visible patches
  - Parameters are updated through backpropagation
- **Target Encoder (ViT-B/32)**
  - Processes masked patches
  - Updated via EMA or frozen (configurable)
  - No gradients flow through this encoder
- **Predictor Network**
  - Lightweight transformer (6 layers by default)
  - Predicts target embeddings from context embeddings
  - Only attached to the context encoder branch
- **Random**: Randomly masks individual patches
- **Block**: Masks a single contiguous block
- **Multi-block**: Masks multiple contiguous blocks (default)
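As an illustration, multi-block sampling on the 7×7 patch grid of ViT-B/32 at 224×224 might look like this (a simplified sketch; the actual collator in `masking.py` also handles the mask ratio and block aspect ratios):

```python
import random

def sample_multi_block_mask(grid=7, num_masks=4, block=2, seed=None):
    """Return masked patch indices: `num_masks` blocks of
    `block` x `block` patches on a `grid` x `grid` layout.
    Blocks may overlap, so the masked fraction is at most
    num_masks * block**2 / grid**2.
    """
    rng = random.Random(seed)
    masked = set()
    for _ in range(num_masks):
        top = rng.randrange(grid - block + 1)
        left = rng.randrange(grid - block + 1)
        for r in range(top, top + block):
            for c in range(left, left + block):
                masked.add(r * grid + c)
    return masked
```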
```bash
python train_ijepa.py \
    --encoder_name vit_base_patch32_224 \
    --predictor_depth 6 \
    --predictor_embed_dim 384 \
    --mask_ratio 0.75 \
    --mask_strategy multi_block \
    --num_masks 4 \
    --batch_size 256 \
    --num_epochs 100 \
    --lr 1e-4 \
    --weight_decay 0.05
```

Key arguments:

- `--encoder_name`: ViT model variant (default: `vit_base_patch32_224`)
- `--use_pretrained`: Use pretrained ViT weights
- `--mask_ratio`: Fraction of patches to mask (default: 0.75)
- `--mask_strategy`: Masking strategy (`random`, `block`, `multi_block`)
- `--momentum`: EMA momentum for target encoder (default: 0.996)
- `--gradient_clip`: Gradient clipping value (default: 1.0)
To resume training from a checkpoint:

```bash
python train_ijepa.py \
    --resume ./checkpoints/checkpoint_epoch_50.pt \
    [other arguments...]
```

Test your trained model on a single image:
```bash
python inference.py \
    --model_path ./checkpoints/best_model.pt \
    --input /path/to/image.jpg \
    --output_dir ./results
```

Process multiple images from a directory:
```bash
python inference.py \
    --model_path ./checkpoints/best_model.pt \
    --input /path/to/image/directory \
    --output_dir ./batch_results \
    --max_images 20
```

See how different masking strategies perform on the same image:
```bash
python inference.py \
    --model_path ./checkpoints/best_model.pt \
    --input /path/to/image.jpg \
    --compare_strategies \
    --output_dir ./strategy_comparison
```

For detailed visualizations:
```bash
python visualize_ijepa.py \
    --mode inference \
    --model_path ./checkpoints/best_model.pt \
    --image_path /path/to/image.jpg \
    --save_dir ./visualizations
```

You can also use the model directly from Python:

```python
from ijepa_model import IJEPAModel, IJEPALoss
from masking import IJEPAMaskCollator

# Initialize model
model = IJEPAModel(
    encoder_name='vit_base_patch32_224',
    predictor_depth=6,
    predictor_embed_dim=384,
    predictor_mlp_ratio=4.0,
    use_pretrained=True,
    freeze_target_encoder=False,
    momentum=0.996
)

# Create mask collator
mask_collator = IJEPAMaskCollator(
    patch_size=32,
    image_size=224,
    mask_ratio=0.75,
    mask_strategy='multi_block',
    num_masks=4
)

# Loss function
loss_fn = IJEPALoss(loss_type='l2')  # Options: 'l2', 'l1', 'smooth_l1', 'cosine'
```

- **Input Processing**: Images are divided into patches (32×32 pixels)
- **Masking**: A subset of patches is masked using the chosen strategy
- **Encoding**:
  - Context encoder processes visible patches
  - Target encoder processes masked patches (no gradients)
- **Prediction**: Predictor network estimates target embeddings from context
- **Loss**: Minimize the distance between predicted and actual target embeddings
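The steps above can be sketched end to end with toy stand-ins for the encoders (hypothetical shapes and crude mean-pooling for the context; the real model uses ViT encoders and a transformer predictor):

```python
import torch
import torch.nn as nn

# Toy stand-ins for the real ViT encoders and predictor (illustrative only)
embed_dim = 8
context_encoder = nn.Linear(embed_dim, embed_dim)
target_encoder = nn.Linear(embed_dim, embed_dim)
predictor = nn.Linear(embed_dim, embed_dim)

def ijepa_step(patches, context_idx, target_idx):
    """One I-JEPA step on pre-embedded patches of shape (B, N, D).

    1. Encode visible (context) patches with gradients.
    2. Encode target patches without gradients (stop-gradient).
    3. Predict target embeddings from a pooled context summary.
    4. Return the L2 loss between prediction and targets.
    """
    ctx = context_encoder(patches[:, context_idx])   # gradients flow here
    with torch.no_grad():                            # stop-gradient branch
        tgt = target_encoder(patches[:, target_idx])
    pred = predictor(ctx.mean(dim=1, keepdim=True))  # crude context pooling
    return ((pred - tgt) ** 2).mean()

patches = torch.randn(2, 10, embed_dim)
loss = ijepa_step(patches, context_idx=[0, 1, 2, 3], target_idx=[4, 5])
loss.backward()  # updates only the context encoder and predictor
```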
- MAE reconstructs raw pixels → I-JEPA predicts embeddings
- MAE focuses on low-level details → I-JEPA captures semantic information
- MAE uses a single encoder → I-JEPA uses asymmetric dual encoders
- Focuses on semantic understanding rather than texture details
- More efficient (lower dimensional space)
- Better transfer learning performance
- Avoids trivial solutions (e.g., copying nearby pixels)
- CIFAR-10: ~85-90% linear probe accuracy after 100 epochs
- ImageNet: ~70-75% linear probe accuracy after 300 epochs
- **Batch Size**: Larger batches (256+) generally work better
- **Learning Rate**: Scale with batch size (lr = base_lr * batch_size / 256)
- **Masking Ratio**: 0.6-0.8 works well, with 0.75 a good default
- **Multi-block Masking**: Generally performs better than random masking
- **EMA Momentum**: Start with 0.996; increase toward 0.999 for longer training
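The linear learning-rate scaling rule from the tips above, as a quick sketch (`scaled_lr` is a hypothetical helper, not part of the training script):

```python
def scaled_lr(base_lr: float, batch_size: int, base_batch: int = 256) -> float:
    """Linear LR scaling: lr = base_lr * batch_size / base_batch."""
    return base_lr * batch_size / base_batch

# Doubling the batch size doubles the learning rate
print(scaled_lr(1e-4, 512))  # 0.0002
```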
For Apple Silicon (MPS):
- Recommended batch size: 32-128 (depending on model size and memory)
- MPS fallback is enabled for unsupported operations
- Memory usage is generally lower than CUDA
- Training speed is competitive with mid-range GPUs
For CUDA:
- Can handle larger batch sizes (128-512)
- Full operation support
- Best performance for large-scale training
For CPU:
- Use smaller batch sizes (16-64)
- Consider reducing model size for faster iteration
- Good for prototyping and small experiments
Contributions are welcome! Feel free to:
- Report bugs
- Suggest new features
- Submit pull requests
- Improve documentation
This project is released under the MIT License.
- Meta AI Research for the I-JEPA concept and paper
- The `timm` library for pretrained ViT models
- The PyTorch team for the excellent deep learning framework

Note: This is an educational implementation of I-JEPA. For production use, consider the official implementation or further optimizations.