I-JEPA: Image Joint-Embedding Predictive Architecture

A PyTorch implementation of I-JEPA (Image Joint-Embedding Predictive Architecture) using pretrained Vision Transformers (ViT-B/32) as encoders. This implementation follows the principles from the original I-JEPA paper, focusing on learning semantic representations by predicting masked image regions in the embedding space rather than pixel space.

📋 Overview

I-JEPA is a self-supervised learning method that:

  • Predicts embeddings of masked image regions rather than reconstructing pixels
  • Learns semantic representations that capture high-level image understanding
  • Uses asymmetric architecture with context and target encoders
  • Avoids representation collapse through stop-gradient and EMA updates

Key Features

  • ✅ Pretrained ViT-B/32 encoders from timm library
  • ✅ Multiple masking strategies (random, block, multi-block)
  • ✅ Lightweight transformer predictor network
  • ✅ EMA (Exponential Moving Average) for target encoder updates
  • ✅ Comprehensive visualization tools
  • ✅ Support for both CIFAR-10 and ImageNet datasets
  • ✅ Wandb integration for experiment tracking

🛠️ Installation

# Clone the repository
git clone https://github.com/ege-dgny/epipolar-jepa.git B-Jepa
cd B-Jepa

# Install dependencies
pip install -r requirements.txt

Device Support

The implementation automatically detects and uses the best available device (a detection sketch follows the list):

  • CUDA - NVIDIA GPUs (highest priority)
  • MPS - Apple Silicon M1/M2/M3 GPUs (Metal Performance Shaders)
  • CPU - Fallback for any system
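
The repo's check_device.py handles this selection; a minimal sketch of the same priority order, using standard PyTorch availability checks (the actual utility may differ in details):

import torch

def get_best_device() -> torch.device:
    """Pick the best available device: CUDA > MPS > CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        # Apple Silicon GPU via Metal Performance Shaders
        return torch.device("mps")
    return torch.device("cpu")

print(f"Using device: {get_best_device()}")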

You can also manually specify a device:

python train_ijepa.py --device mps    # Force MPS
python train_ijepa.py --device cuda   # Force CUDA
python train_ijepa.py --device cpu    # Force CPU

🚀 Quick Start

1. Run the Demo

The easiest way to get started is with the example script:

python example_usage.py

This will:

  • Test all model components
  • Visualize different masking strategies
  • Run a quick training demo on CIFAR-10

2. Train on CIFAR-10

For a quick training run on CIFAR-10:

python train_ijepa.py \
    --batch_size 128 \
    --num_epochs 50 \
    --lr 1e-4 \
    --use_pretrained \
    --data_dir ./data

The script will automatically detect and use the best available device (CUDA > MPS > CPU).

3. Train on ImageNet

For ImageNet training:

python train_ijepa.py \
    --batch_size 256 \
    --num_epochs 100 \
    --lr 1e-4 \
    --use_pretrained \
    --use_imagenet \
    --data_dir /path/to/imagenet \
    --wandb \
    --wandb_project ijepa_imagenet

📁 Project Structure

B-Jepa/
├── ijepa_model.py          # Core I-JEPA model implementation
├── masking.py              # Masking strategies and collators
├── train_ijepa.py          # Training script with trainer class
├── inference.py            # Comprehensive inference and analysis
├── visualize_ijepa.py      # Visualization utilities
├── example_usage.py        # Demo and testing script
├── check_device.py         # Device detection utility
├── requirements.txt        # Dependencies
└── README.md              # This file

🏗️ Architecture

Model Components

  1. Context Encoder (ViT-B/32)

    • Processes visible patches
    • Parameters are updated through backpropagation
  2. Target Encoder (ViT-B/32)

    • Processes masked patches
    • Updated via EMA or kept frozen (configurable); see the EMA sketch after this list
    • No gradients flow through this encoder
  3. Predictor Network

    • Lightweight transformer (6 layers by default)
    • Predicts target embeddings from context embeddings
    • Only attached to context encoder branch
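
The EMA update for the target encoder (component 2) can be sketched as follows; the function and argument names are illustrative, not the repo's exact API:

import torch

@torch.no_grad()
def ema_update(target_encoder, context_encoder, momentum=0.996):
    # theta_target <- m * theta_target + (1 - m) * theta_context
    # Runs under no_grad: no gradients ever flow into the target encoder.
    for p_t, p_c in zip(target_encoder.parameters(), context_encoder.parameters()):
        p_t.mul_(momentum).add_(p_c, alpha=1.0 - momentum)

With momentum close to 1, the target encoder changes slowly, which stabilizes the prediction targets and helps avoid representation collapse.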

Masking Strategies

  • Random: Randomly masks individual patches
  • Block: Masks a single contiguous block
  • Multi-block: Masks multiple contiguous blocks (default); sketched below
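
A rough sketch of the multi-block strategy on the 7×7 patch grid that ViT-B/32 produces for 224×224 inputs (the actual IJEPAMaskCollator also handles batching and the mask ratio):

import torch

def multi_block_mask(grid=7, num_masks=4, block=2):
    """Boolean (grid*grid,) mask with `num_masks` rectangular blocks set True."""
    mask = torch.zeros(grid, grid, dtype=torch.bool)
    for _ in range(num_masks):
        top = torch.randint(0, grid - block + 1, (1,)).item()
        left = torch.randint(0, grid - block + 1, (1,)).item()
        mask[top:top + block, left:left + block] = True  # one contiguous block
    return mask.flatten()

mask = multi_block_mask()
print(f"{mask.sum().item()} of {mask.numel()} patches masked")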

📊 Training

Basic Training Command

python train_ijepa.py \
    --encoder_name vit_base_patch32_224 \
    --predictor_depth 6 \
    --predictor_embed_dim 384 \
    --mask_ratio 0.75 \
    --mask_strategy multi_block \
    --num_masks 4 \
    --batch_size 256 \
    --num_epochs 100 \
    --lr 1e-4 \
    --weight_decay 0.05

Key Training Parameters

  • --encoder_name: ViT model variant (default: vit_base_patch32_224)
  • --use_pretrained: Use pretrained ViT weights
  • --mask_ratio: Fraction of patches to mask (default: 0.75)
  • --mask_strategy: Masking strategy (random, block, multi_block)
  • --momentum: EMA momentum for target encoder (default: 0.996)
  • --gradient_clip: Gradient clipping value (default: 1.0)

Resume Training

python train_ijepa.py \
    --resume ./checkpoints/checkpoint_epoch_50.pt \
    [other arguments...]

🔍 Inference & Analysis

Single Image Inference

Test your trained model on a single image:

python inference.py \
    --model_path ./checkpoints/best_model.pt \
    --input /path/to/image.jpg \
    --output_dir ./results

Batch Inference

Process multiple images from a directory:

python inference.py \
    --model_path ./checkpoints/best_model.pt \
    --input /path/to/image/directory \
    --output_dir ./batch_results \
    --max_images 20

Compare Masking Strategies

See how different masking strategies perform on the same image:

python inference.py \
    --model_path ./checkpoints/best_model.pt \
    --input /path/to/image.jpg \
    --compare_strategies \
    --output_dir ./strategy_comparison

Advanced Visualization

For detailed visualizations:

python visualize_ijepa.py \
    --mode inference \
    --model_path ./checkpoints/best_model.pt \
    --image_path /path/to/image.jpg \
    --save_dir ./visualizations

📈 Model Configuration

Create Custom Model

from ijepa_model import IJEPAModel, IJEPALoss
from masking import IJEPAMaskCollator

# Initialize model
model = IJEPAModel(
    encoder_name='vit_base_patch32_224',
    predictor_depth=6,
    predictor_embed_dim=384,
    predictor_mlp_ratio=4.0,
    use_pretrained=True,
    freeze_target_encoder=False,
    momentum=0.996
)

# Create mask collator
mask_collator = IJEPAMaskCollator(
    patch_size=32,
    image_size=224,
    mask_ratio=0.75,
    mask_strategy='multi_block',
    num_masks=4
)

# Loss function
loss_fn = IJEPALoss(loss_type='l2')  # Options: 'l2', 'l1', 'smooth_l1', 'cosine'
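
To feed masked batches to the model, the collator can be passed as a DataLoader's collate_fn. The exact batch format IJEPAMaskCollator returns is specific to this repo, so treat this wiring as an illustrative assumption:

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = datasets.CIFAR10('./data', train=True, download=True, transform=transform)

# Hypothetical: the collator yields (images, context_masks, target_masks) per batch
loader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=mask_collator)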

🔬 Understanding I-JEPA

How It Works

  1. Input Processing: Images are divided into patches (32x32 pixels)
  2. Masking: A subset of patches is masked using the chosen strategy
  3. Encoding:
    • Context encoder processes visible patches
    • Target encoder processes masked patches (no gradients)
  4. Prediction: Predictor network estimates target embeddings from context
  5. Loss: Minimize the distance between predicted and actual target embeddings (steps 2-5 are sketched in code below)
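
Steps 2-5 as a schematic training step; the encoder and predictor call signatures below are generic stand-ins, not the repo's exact API:

import torch
import torch.nn.functional as F

def train_step(context_encoder, target_encoder, predictor, optimizer,
               images, context_mask, target_mask):
    # Step 3a: encode visible patches; gradients flow through this branch
    context_emb = context_encoder(images, context_mask)

    # Step 3b: encode target patches under stop-gradient
    with torch.no_grad():
        target_emb = target_encoder(images, target_mask)

    # Step 4: predict target embeddings from the context
    pred_emb = predictor(context_emb, target_mask)

    # Step 5: distance in embedding space, not pixel space
    loss = F.mse_loss(pred_emb, target_emb)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

After each step, the target encoder is moved toward the context encoder by the EMA update shown in the Architecture section.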

Key Differences from MAE

  • MAE reconstructs raw pixels → I-JEPA predicts embeddings
  • MAE focuses on low-level details → I-JEPA captures semantic information
  • MAE uses single encoder → I-JEPA uses asymmetric dual encoders

Why Predict Embeddings?

  • Focuses on semantic understanding rather than texture details
  • More efficient (lower dimensional space)
  • Better transfer learning performance
  • Avoids trivial solutions (e.g., copying nearby pixels)

📝 Experimental Results

Expected Performance

  • CIFAR-10: ~85-90% linear probe accuracy after 100 epochs
  • ImageNet: ~70-75% linear probe accuracy after 300 epochs (see the linear-probe sketch below)
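
These numbers refer to linear probing: a single linear classifier trained on frozen encoder features. A minimal sketch, assuming the encoder returns per-patch embeddings of dimension 768 (ViT-B):

import torch
import torch.nn as nn

@torch.no_grad()
def extract_features(encoder, loader, device):
    """Collect frozen, mean-pooled encoder features plus labels."""
    encoder.eval()
    feats, labels = [], []
    for images, targets in loader:
        emb = encoder(images.to(device))     # assumed shape: (B, num_patches, 768)
        feats.append(emb.mean(dim=1).cpu())  # mean-pool patches into one vector
        labels.append(targets)
    return torch.cat(feats), torch.cat(labels)

# The probe itself is just a linear layer trained on these features;
# the encoder stays frozen throughout.
probe = nn.Linear(768, 10)  # 10 classes for CIFAR-10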

Training Tips

  1. Batch Size: Larger batches (256+) generally work better
  2. Learning Rate: Scale with batch size (lr = base_lr * batch_size / 256)
  3. Masking Ratio: 0.6-0.8 works well; 0.75 is the default here
  4. Multi-block Masking: Generally outperforms random masking
  5. EMA Momentum: Start at 0.996 and ramp toward 0.999 for longer runs (a schedule sketch follows this list)
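
Tips 2 and 5 as code: the linear LR scaling rule, and a linear momentum ramp (the ramp shape is a common choice, not necessarily what train_ijepa.py implements):

def scaled_lr(base_lr, batch_size, base_batch=256):
    """Linear scaling rule: lr = base_lr * batch_size / base_batch."""
    return base_lr * batch_size / base_batch

def ema_momentum(step, total_steps, start=0.996, end=0.999):
    """Ramp EMA momentum linearly from `start` to `end` over training."""
    frac = min(step / max(total_steps, 1), 1.0)
    return start + (end - start) * frac

print(scaled_lr(1e-4, 512))     # 2e-04 for batch size 512
print(ema_momentum(500, 1000))  # 0.9975 halfway through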

Device-Specific Tips

For Apple Silicon (MPS):

  • Recommended batch size: 32-128 (depending on model size and memory)
  • MPS fallback is enabled for unsupported operations
  • Memory usage is generally lower than on comparable CUDA setups
  • Training speed can be competitive with mid-range discrete GPUs

For CUDA:

  • Can handle larger batch sizes (128-512)
  • Full operation support
  • Best performance for large-scale training

For CPU:

  • Use smaller batch sizes (16-64)
  • Consider reducing model size for faster iteration
  • Good for prototyping and small experiments

🤝 Contributing

Contributions are welcome! Feel free to:

  • Report bugs
  • Suggest new features
  • Submit pull requests
  • Improve documentation

📚 References

  • Assran, M., Duval, Q., Misra, I., Bojanowski, P., Vincent, P., Rabbat, M., LeCun, Y., & Ballas, N. (2023). Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. CVPR 2023. arXiv:2301.08243

📄 License

This project is released under the MIT License.

🙏 Acknowledgments

  • Meta AI Research for the I-JEPA concept and paper
  • The timm library for pretrained ViT models
  • The PyTorch team for the excellent deep learning framework

Note: This is an educational implementation of I-JEPA. For production use, consider the official implementation or further optimizations.
