Feral JEPA

CPU-trainable Joint Embedding Predictive Architecture in Rust

No pixels predicted. No labels. No contrastive pairs. Pure structure, emerging from prediction in latent space.

cargo run --release -- train --demo     # run the demo (~30s on CPU)
cargo run --release -- mask-demo        # visualize block masking
cargo run --release -- info             # architecture details
cargo run --release -- sanity           # quick forward pass checks
cargo test                              # 21/21 tests

What Is JEPA?

JEPA (Joint Embedding Predictive Architecture) is a self-supervised learning framework from Meta AI, introduced by Assran et al. (CVPR 2023) and championed by Yann LeCun.

The key insight: predict representations, not pixels.

Instead of reconstructing masked pixels (like MAE), the predictor learns to output the abstract embeddings that a momentum target encoder would produce at masked positions. The result is more semantically meaningful, with no augmentations and no contrastive pairs.

Input Image
    ↓
Block Masking (sample target region T, context region C)
    ↓                        ↓
Context Encoder (θ)         Target Encoder (θ̄)   ← EMA-updated, no grad
 processes C patches          processes ALL patches
    ↓                        ↓
 context_embeddings          target_embeddings  (ground truth)
    ↓
Predictor (φ) ← context_emb + positional mask tokens at target positions
    ↓
 predicted_embeddings

Loss = MSE(predicted_embeddings, target_embeddings)   ← in latent space only

Demo Output

cargo run --release -- train --demo

┌─ Configuration ─────────────────────────────┐
│  Image: 32×32×3                             │
│  Patches: 4×4 = 16 (patch_size=8)           │
│  Encoder: dim=64, depth=4, heads=4          │
│  Predictor: dim=32, depth=2                 │
│  EMA tau: 0.9960 → 0.9999                   │
│  Training: 20ep, bs=16, lr=1e-3             │
│  Samples: 256                               │
└─────────────────────────────────────────────┘

Training components initialized:
  Model: JepaModel<NdArray>
  Optimizer: AdamW (lr=1e-3)
  EMA: tau=0.9960 → 0.9999

  Dataset: 256 synthetic images generated

Demonstrating JEPA forward pass (20 steps):

  Step  1 | Loss: 1.225462 | ctx_patches: 15 | tgt_patches: 1
  Step  2 | Loss: 1.251864 | ctx_patches: 15 | tgt_patches: 1
  Step  3 | Loss: 1.212248 | ctx_patches: 15 | tgt_patches: 1
  Step  4 | Loss: 1.268265 | ctx_patches: 12 | tgt_patches: 4
  Step  5 | Loss: 1.198469 | ctx_patches: 14 | tgt_patches: 2
  Step  6 | Loss: 1.264925 | ctx_patches: 14 | tgt_patches: 2
  Step  7 | Loss: 1.213228 | ctx_patches: 15 | tgt_patches: 1
  Step  8 | Loss: 1.234686 | ctx_patches: 15 | tgt_patches: 1
  Step  9 | Loss: 1.240172 | ctx_patches: 12 | tgt_patches: 4
  Step 10 | Loss: 1.249817 | ctx_patches: 12 | tgt_patches: 4
  Step 11 | Loss: 1.183829 | ctx_patches: 12 | tgt_patches: 4
  Step 12 | Loss: 1.181217 | ctx_patches: 14 | tgt_patches: 2
  Step 13 | Loss: 1.293803 | ctx_patches: 14 | tgt_patches: 2
  Step 14 | Loss: 1.269303 | ctx_patches: 15 | tgt_patches: 1
  Step 15 | Loss: 1.383039 | ctx_patches: 14 | tgt_patches: 2
  Step 16 | Loss: 1.261974 | ctx_patches: 14 | tgt_patches: 2
  Step 17 | Loss: 1.262577 | ctx_patches: 15 | tgt_patches: 1
  Step 18 | Loss: 1.262430 | ctx_patches: 14 | tgt_patches: 2
  Step 19 | Loss: 1.234146 | ctx_patches: 12 | tgt_patches: 4
  Step 20 | Loss: 1.308062 | ctx_patches: 15 | tgt_patches: 1

┌─ Training Loss ─────────────────────────────────────────┐
│ max: 1.383039                                            │
│                                           ███                │
│                                           ███                │
│                                           ███                │
│                                           ███                │
│                                           ███                │
│                                           ███            ███ │
│                                     ███   ███            ███ │
│                                     ███   ███            ███ │
│          ███   ███                  ██████████████████   ███ │
│    ███   ███   ███         ███      ██████████████████   ███ │
│    ███   ███   ███   █████████      ████████████████████████ │
│ ██████   ███   ███   █████████      ████████████████████████ │
│ ████████████   ███████████████      ████████████████████████ │
│ ██████████████████████████████      ████████████████████████ │
│ ████████████████████████████████████████████████████████████ │
│ min: 1.181217                                            │
│ epochs: 1→20                                         │
└──────────────────────────────────────────────────────────┘

JEPA mechanism demonstrated successfully.

What just happened:
  1. Images split into 4×4 = 16 patches
  2. Block masking sampled: ~12 context, ~4 target patches
  3. Context encoder processed visible patches
  4. Predictor hallucinated target embeddings from context
  5. Target encoder produced ground-truth embeddings (all patches)
  6. MSE loss = difference between predicted and target embeddings

Full training with backprop would minimize this loss over epochs.
EMA update would slowly move target encoder toward context encoder.
Result: representations that capture semantic structure.

cargo run --release -- mask-demo

Block Masking Demo

JEPA samples rectangular blocks from the patch grid.
Context encoder sees everything EXCEPT target blocks.
Predictor must predict what target encoder sees at target positions.

Showing 5 different mask samples for a 4×4 patch grid:

Sample 1:                 Sample 4:
┌────────┐                ┌────────┐
│C C C C │                │C C C C │
│C C C C │                │T T C C │
│C T C C │                │T T C C │
│C C C C │                │C C C C │
└────────┘                └────────┘
  ctx: 15  tgt: 1           ctx: 12  tgt: 4

Sample 2:                 Sample 5:
┌────────┐                ┌────────┐
│C C C C │                │C C C C │
│T C C C │                │C C T C │
│C C C C │                │C C T C │
│C C C C │                │C C C C │
└────────┘                └────────┘
  ctx: 15  tgt: 1           ctx: 14  tgt: 2

Key insight:
  The predictor sees context (C) but must predict target (T) embeddings.
  It knows WHERE to predict (positional embedding in mask token).
  It learns WHAT to predict from structure in context patches.

cargo run --release -- info

Architecture Details:
─────────────────────
Context Encoder (student):
  ViT-Nano: embed=64, depth=4, heads=4
  Receives gradients via backprop
  Only sees context patches (not target)

Target Encoder (teacher):
  Same architecture as context encoder
  NO gradient updates
  Updated via EMA: θ̄ ← 0.9960·θ̄ + 0.0040·θ
  Provides ground-truth targets for loss

Predictor:
  Narrower ViT: embed=32, depth=2
  Input: context embeddings + positional mask tokens
  Output: predicted embeddings at target positions

Loss: MSE(predicted, target_encoder_output)
  Entirely in latent embedding space
  No pixel reconstruction — semantics only

Collapse Prevention:
  EMA creates 'moving target' that the student must chase
  Constant outputs never minimize the loss because
  target_encoder(x) ≠ target_encoder(y) for x≠y

Architecture: JEPA-Nano

Designed to run on CPU. Demoable in ~30 seconds.

Component            Config           Notes
Image size           32×32            CIFAR-10 compatible
Patch size           8×8              4×4 = 16 patches per image
Encoder embed dim    64               vs 768 for ViT-L
Encoder depth        4 blocks         vs 12 for ViT-L
Encoder heads        4
Predictor embed dim  32               narrower than encoder
Predictor depth      2                shallower than encoder
Total params         ~1.1M            vs 632M for ViT-H
EMA tau              0.996 → 0.9999   cosine schedule
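The cosine schedule for the EMA coefficient can be sketched in plain Rust. This is an illustrative version, assuming a linear-in-progress cosine ramp from 0.996 to 0.9999; the function name and signature are hypothetical, not the crate's API.

```rust
/// Cosine ramp for the EMA coefficient tau from tau0 to tau1 over training.
/// Hypothetical sketch of the schedule described above.
fn tau_at(step: usize, total: usize, tau0: f64, tau1: f64) -> f64 {
    let t = step as f64 / total as f64;                        // progress in [0, 1]
    let ramp = 0.5 * (1.0 - (std::f64::consts::PI * t).cos()); // 0 at start, 1 at end
    tau0 + (tau1 - tau0) * ramp
}

fn main() {
    let (t0, t1) = (0.996, 0.9999);
    assert!((tau_at(0, 100, t0, t1) - t0).abs() < 1e-9);   // starts at 0.996
    assert!((tau_at(100, 100, t0, t1) - t1).abs() < 1e-9); // ends at 0.9999
    assert!(tau_at(30, 100, t0, t1) < tau_at(70, 100, t0, t1)); // monotone ramp
    println!("tau mid-training: {:.6}", tau_at(50, 100, t0, t1));
}
```

The cosine shape moves tau slowly near both endpoints, so the target encoder stiffens gradually rather than abruptly.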

The Three-Network System

Context Encoder (Student, θ)

  • Processes only visible patches (context, not target positions)
  • Gets gradients via backprop on every step
  • The network that actually learns representations

Target Encoder (Teacher, θ̄)

  • Same ViT architecture as context encoder
  • Processes all patches
  • Zero gradient updates — weights moved only by EMA:
    θ̄ ← τ · θ̄ + (1 - τ) · θ,   τ: 0.996 → 0.9999
    
  • Acts as a slowly-moving "oracle" the predictor must match

Predictor (φ)

  • Lightweight transformer (half the width, half the depth)
  • Input: context embeddings concatenated with mask tokens at target positions
  • Mask token = shared learnable vector + positional embedding of the target patch
  • This tells the predictor where to predict, not what
  • Output: predicted embedding at each target position
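The mask-token assembly above can be sketched with plain `Vec`s standing in for tensors. This is a hypothetical illustration of the construction, not the crate's API: context embeddings are followed by one token per target position, each the shared learnable vector plus that position's embedding.

```rust
/// Build the predictor's input sequence: context embeddings, then a mask
/// token (shared learnable vector + positional embedding) per target.
/// Hypothetical Vec-based sketch; the repo operates on burn tensors.
fn predictor_input(
    ctx: &[Vec<f32>],     // embeddings of visible (context) patches
    mask_token: &[f32],   // one shared learnable vector
    pos_emb: &[Vec<f32>], // positional embedding per patch index
    targets: &[usize],    // indices of target patches
) -> Vec<Vec<f32>> {
    let mut seq: Vec<Vec<f32>> = ctx.to_vec();
    for &i in targets {
        // The position tells the predictor WHERE; the shared token carries
        // no content, so WHAT must come from attending to context.
        let tok = mask_token.iter().zip(&pos_emb[i]).map(|(m, p)| m + p).collect();
        seq.push(tok);
    }
    seq
}

fn main() {
    let ctx = vec![vec![0.1_f32, 0.2]; 14]; // 14 context patches
    let pos: Vec<Vec<f32>> = (0..16).map(|i| vec![i as f32, 0.0]).collect();
    let seq = predictor_input(&ctx, &[0.5, 0.5], &pos, &[5, 6]);
    assert_eq!(seq.len(), 16);                // 14 context + 2 mask tokens
    assert!((seq[14][0] - 5.5).abs() < 1e-6); // shared token + position 5
    println!("predictor sequence length: {}", seq.len());
}
```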

Why EMA Prevents Collapse

Self-supervised methods face representation collapse — the model learns to output the same constant embedding for every input, trivially minimizing any similarity loss.

JEPA's defense:

Suppose context_encoder collapses: f(x) = constant for all x.

Then predictor_output = constant.

But target_encoder(x) = EMA of {recent θ checkpoints applied to x}
                       ≠ constant  (because x varies and EMA ≠ current θ)

So MSE(constant, target_encoder(x)) stays high for varied x.
The only way to minimize loss: produce varied, informative embeddings.

The EMA creates a "moving target" that the student must genuinely chase. If τ is too high (τ → 1), the target barely moves and never absorbs the student's progress; if it is too low (τ → 0), the target tracks the student step for step and collapse becomes possible again. The 0.996 → 0.9999 schedule is empirically stable.
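The EMA rule θ̄ ← τ·θ̄ + (1-τ)·θ is a one-liner per parameter. A minimal sketch over flat `f32` slices standing in for model weights (the repo applies this over burn module parameters instead):

```rust
/// One EMA step moving the target encoder's weights toward the context
/// encoder's: theta_bar ← tau * theta_bar + (1 - tau) * theta.
/// Sketch over flat slices standing in for real model parameters.
fn ema_update(target: &mut [f32], context: &[f32], tau: f32) {
    for (t, c) in target.iter_mut().zip(context.iter()) {
        *t = tau * *t + (1.0 - tau) * *c;
    }
}

fn main() {
    let mut target = vec![0.0_f32; 3];
    let context = vec![1.0_f32; 3];
    ema_update(&mut target, &context, 0.996);
    // With tau = 0.996, the target moves only 0.4% of the way per step.
    assert!((target[0] - 0.004).abs() < 1e-6);
    println!("target after one EMA step: {:?}", target);
}
```

Because the target is a slow average of past student weights, it lags the student by construction, which is exactly what makes it a non-trivial prediction target.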


Block Masking

4×4 patch grid example:

┌────────┐
│C C C C │   C = context (visible to context encoder)
│C C T C │   T = target  (predicted by predictor,
│C T T C │              grounded by target encoder)
│C C C C │
└────────┘

Scale: target block = 15–40% of patches
Aspect ratio: 0.75–1.5
Min context: 8 patches always retained

Large rectangular blocks force semantic reasoning, not local texture interpolation. If you could infer T from adjacent C patches trivially, the predictor wouldn't need to learn anything about global structure.
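The rectangular-block idea can be sketched deterministically. This is a simplified illustration: the repo's sampler additionally draws the block's scale (15–40% of patches) and aspect ratio (0.75–1.5) at random and always retains at least 8 context patches; the function here just marks one given rectangle.

```rust
/// Mark a rectangular target block on an n×n patch grid (row-major).
/// Simplified sketch of block masking; true = target, false = context.
fn block_mask(n: usize, top: usize, left: usize, h: usize, w: usize) -> Vec<bool> {
    let mut mask = vec![false; n * n];
    for r in top..(top + h).min(n) {
        for c in left..(left + w).min(n) {
            mask[r * n + c] = true;
        }
    }
    mask
}

fn main() {
    // A 2×2 target block at row 1, col 0 on a 4×4 grid: 12 ctx, 4 tgt.
    let mask = block_mask(4, 1, 0, 2, 2);
    let tgt = mask.iter().filter(|&&m| m).count();
    assert_eq!(tgt, 4);
    assert_eq!(mask.len() - tgt, 12);
    // Render the grid the way mask-demo does.
    for r in 0..4 {
        let row: Vec<&str> =
            (0..4).map(|c| if mask[r * 4 + c] { "T" } else { "C" }).collect();
        println!("{}", row.join(" "));
    }
}
```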


Loss

L = (1/M) Σᵢ ||ŝᵢ - sᵢ||²

where:
  ŝᵢ = predictor output at target patch i       (predicted)
  sᵢ = target_encoder output at target patch i  (ground truth, no grad)
  M  = total target patches in batch

MSE in embedding space. No reconstruction, no negatives, no augmentation. Just: "predict what the oracle would output here."

Also implemented: L1 loss (V-JEPA 2 style) and cosine similarity loss.
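The loss above is straightforward to write out. A plain-`Vec` sketch of L = (1/M) Σᵢ ||ŝᵢ - sᵢ||² (the repo computes the same quantity on burn tensors):

```rust
/// Latent-space MSE: per-patch squared L2 distance between predicted and
/// target embeddings, averaged over the M target patches.
fn latent_mse(pred: &[Vec<f32>], tgt: &[Vec<f32>]) -> f32 {
    assert_eq!(pred.len(), tgt.len());
    let m = pred.len() as f32;
    pred.iter()
        .zip(tgt)
        .map(|(p, t)| p.iter().zip(t).map(|(a, b)| (a - b).powi(2)).sum::<f32>())
        .sum::<f32>()
        / m
}

fn main() {
    // Two target patches, embedding dim 2.
    let pred = vec![vec![1.0_f32, 2.0], vec![0.0, 0.0]];
    let tgt = vec![vec![0.0_f32, 0.0], vec![0.0, 0.0]];
    // Patch 0 contributes 1 + 4 = 5, patch 1 contributes 0; mean = 2.5.
    assert!((latent_mse(&pred, &tgt) - 2.5).abs() < 1e-6);
    println!("loss = {}", latent_mse(&pred, &tgt));
}
```

Swapping the inner `(a - b).powi(2)` for `(a - b).abs()` gives the L1 variant; normalizing both vectors and taking a dot product gives the cosine variant.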


Project Structure

src/
├── main.rs                CLI: train, mask-demo, info, sanity
├── model/
│   ├── patch_embed.rs     Image [B,C,H,W] → patches [B,N,E]
│   ├── attention.rs       Multi-head self-attention (Q,K,V projections)
│   ├── transformer.rs     Pre-norm transformer block + stack
│   ├── vit.rs             Vision Transformer (full & masked forward)
│   └── jepa.rs            JepaModel: context_encoder + target_encoder + predictor
├── data/
│   ├── synthetic.rs       6 types of synthetic images, zero download
│   └── batch.rs           Vec<f32> → Tensor<B,4> utilities
├── train/
│   ├── masking.rs         Block mask sampler (configurable scale/aspect)
│   ├── loss.rs            MSE, L1, cosine losses in latent space
│   ├── ema.rs             EMA tau scheduler (cosine, 0.996 → 0.9999)
│   └── trainer.rs         Training loop structure + demo runner
└── demo/
    └── visualize.rs       ASCII loss curves, mask grids, heatmaps

research/
├── 01-jepa-architecture.md       Full architecture breakdown + diagrams
├── 02-implementation-decisions.md  Why burn, why nano scale, data flow
└── 03-training-dynamics.md       Math of collapse prevention, EMA dynamics

Running

# Quick demo — 20 forward steps, shows full mechanism, ~30s on CPU
cargo run --release -- train --demo

# Full forward-pass run (50 steps, 1000 samples)
cargo run --release -- train --epochs 50 --samples 1000

# Block masking visualization
cargo run --release -- mask-demo

# Architecture + parameter details
cargo run --release -- info

# Quick forward pass sanity check
cargo run --release -- sanity

# Full test suite (21 tests)
cargo test

Implementation Notes

Stack:

  • Pure Rust, burn 0.16 ML framework
  • NdArray backend — no CUDA, no libtorch, no C++ deps, runs everywhere
  • 21/21 unit tests pass

Extending to full training with backprop:

// Swap NdArray for Autodiff<NdArray> to get gradients:
type TrainBackend = burn::backend::Autodiff<NdArray<f32>>;

// Then use burn's optimizer:
let optimizer = AdamConfig::new().with_weight_decay(...).init();
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optimizer.step(lr, model, grads);

// EMA update after each step:
// target_encoder ← tau * target_encoder + (1-tau) * context_encoder

Swapping to GPU:

# In Cargo.toml, change the backend features:
burn = { version = "0.16", features = ["wgpu", "autodiff"] }

// In the source, swap the backend type:
type Backend = burn::backend::Autodiff<burn::backend::Wgpu>;

Paper

Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture. Assran et al., CVPR 2023, Meta AI Research.
