CPU-trainable Joint Embedding Predictive Architecture in Rust
No pixels predicted. No labels. No contrastive pairs. Pure structure, emerging from prediction in latent space.
```bash
cargo run --release -- train --demo   # run the demo (~30s on CPU)
cargo run --release -- mask-demo      # visualize block masking
cargo run --release -- info           # architecture details
cargo run --release -- sanity         # quick forward pass checks
cargo test                            # 21/21 tests
```
JEPA (Joint Embedding Predictive Architecture) is a self-supervised learning framework from Meta AI, proposed by Yann LeCun and realized for images as I-JEPA (Assran et al., CVPR 2023).
The key insight: predict representations, not pixels.
Instead of reconstructing masked pixels (like MAE), the predictor learns to output the abstract embeddings that a momentum target encoder would produce at the masked positions, which makes the targets more semantically meaningful than raw pixels. No augmentations. No contrastive pairs.
Input Image
↓
Block Masking (sample target region T, context region C)
↓ ↓
Context Encoder (θ) Target Encoder (θ̄) ← EMA-updated, no grad
processes C patches processes ALL patches
↓ ↓
context_embeddings target_embeddings (ground truth)
↓
Predictor (φ) ← context_emb + positional mask tokens at target positions
↓
predicted_embeddings
Loss = MSE(predicted_embeddings, target_embeddings) ← in latent space only
┌─ Configuration ─────────────────────────────┐
│ Image: 32×32×3 │
│ Patches: 4×4 = 16 (patch_size=8) │
│ Encoder: dim=64, depth=4, heads=4 │
│ Predictor: dim=32, depth=2 │
│ EMA tau: 0.9960 → 0.9999 │
│ Training: 20ep, bs=16, lr=1e-3 │
│ Samples: 256 │
└─────────────────────────────────────────────┘
Training components initialized:
Model: JepaModel<NdArray>
Optimizer: AdamW (lr=1e-3)
EMA: tau=0.9960 → 0.9999
Dataset: 256 synthetic images generated
Demonstrating JEPA forward pass (20 steps):
Step 1 | Loss: 1.225462 | ctx_patches: 15 | tgt_patches: 1
Step 2 | Loss: 1.251864 | ctx_patches: 15 | tgt_patches: 1
Step 3 | Loss: 1.212248 | ctx_patches: 15 | tgt_patches: 1
Step 4 | Loss: 1.268265 | ctx_patches: 12 | tgt_patches: 4
Step 5 | Loss: 1.198469 | ctx_patches: 14 | tgt_patches: 2
Step 6 | Loss: 1.264925 | ctx_patches: 14 | tgt_patches: 2
Step 7 | Loss: 1.213228 | ctx_patches: 15 | tgt_patches: 1
Step 8 | Loss: 1.234686 | ctx_patches: 15 | tgt_patches: 1
Step 9 | Loss: 1.240172 | ctx_patches: 12 | tgt_patches: 4
Step 10 | Loss: 1.249817 | ctx_patches: 12 | tgt_patches: 4
Step 11 | Loss: 1.183829 | ctx_patches: 12 | tgt_patches: 4
Step 12 | Loss: 1.181217 | ctx_patches: 14 | tgt_patches: 2
Step 13 | Loss: 1.293803 | ctx_patches: 14 | tgt_patches: 2
Step 14 | Loss: 1.269303 | ctx_patches: 15 | tgt_patches: 1
Step 15 | Loss: 1.383039 | ctx_patches: 14 | tgt_patches: 2
Step 16 | Loss: 1.261974 | ctx_patches: 14 | tgt_patches: 2
Step 17 | Loss: 1.262577 | ctx_patches: 15 | tgt_patches: 1
Step 18 | Loss: 1.262430 | ctx_patches: 14 | tgt_patches: 2
Step 19 | Loss: 1.234146 | ctx_patches: 12 | tgt_patches: 4
Step 20 | Loss: 1.308062 | ctx_patches: 15 | tgt_patches: 1
┌─ Training Loss ─────────────────────────────────────────┐
│ max: 1.383039 │
│ ███ │
│ ███ │
│ ███ │
│ ███ │
│ ███ │
│ ███ ███ │
│ ███ ███ ███ │
│ ███ ███ ███ │
│ ███ ███ ██████████████████ ███ │
│ ███ ███ ███ ███ ██████████████████ ███ │
│ ███ ███ ███ █████████ ████████████████████████ │
│ ██████ ███ ███ █████████ ████████████████████████ │
│ ████████████ ███████████████ ████████████████████████ │
│ ██████████████████████████████ ████████████████████████ │
│ ████████████████████████████████████████████████████████████ │
│ min: 1.181217 │
│ epochs: 1→20 │
└──────────────────────────────────────────────────────────┘
JEPA mechanism demonstrated successfully.
What just happened:
1. Images split into 4×4 = 16 patches
2. Block masking sampled: ~12 context, ~4 target patches
3. Context encoder processed visible patches
4. Predictor hallucinated target embeddings from context
5. Target encoder produced ground-truth embeddings (all patches)
6. MSE loss = difference between predicted and target embeddings
Full training with backprop would minimize this loss over epochs.
EMA update would slowly move target encoder toward context encoder.
Result: representations that capture semantic structure.
Block Masking Demo
JEPA samples rectangular blocks from the patch grid.
Context encoder sees everything EXCEPT target blocks.
Predictor must predict what target encoder sees at target positions.
Showing 5 different mask samples for a 4×4 patch grid:
Sample 1: Sample 4:
┌────────┐ ┌────────┐
│C C C C │ │C C C C │
│C C C C │ │T T C C │
│C T C C │ │T T C C │
│C C C C │ │C C C C │
└────────┘ └────────┘
ctx: 15 tgt: 1 ctx: 12 tgt: 4
Sample 2: Sample 5:
┌────────┐ ┌────────┐
│C C C C │ │C C C C │
│T C C C │ │C C T C │
│C C C C │ │C C T C │
│C C C C │ │C C C C │
└────────┘ └────────┘
ctx: 15 tgt: 1 ctx: 14 tgt: 2
Key insight:
The predictor sees context (C) but must predict target (T) embeddings.
It knows WHERE to predict (positional embedding in mask token).
It learns WHAT to predict from structure in context patches.
Architecture Details:
─────────────────────
Context Encoder (student):
ViT-Nano: embed=64, depth=4, heads=4
Receives gradients via backprop
Only sees context patches (not target)
Target Encoder (teacher):
Same architecture as context encoder
NO gradient updates
Updated via EMA: θ̄ ← 0.9960·θ̄ + 0.0040·θ
Provides ground-truth targets for loss
Predictor:
Narrower ViT: embed=32, depth=2
Input: context embeddings + positional mask tokens
Output: predicted embeddings at target positions
Loss: MSE(predicted, target_encoder_output)
Entirely in latent embedding space
No pixel reconstruction — semantics only
Collapse Prevention:
EMA creates 'moving target' that the student must chase
Constant outputs never minimize the loss because
target_encoder(x) ≠ target_encoder(y) for x≠y
Designed to run on CPU. Demoable in ~30 seconds.
| Component | Config | Notes |
|---|---|---|
| Image size | 32×32 | CIFAR-10 compatible |
| Patch size | 8×8 | 4×4 = 16 patches per image |
| Encoder embed dim | 64 | vs 768 for ViT-L |
| Encoder depth | 4 transformer blocks | vs 12 for ViT-L |
| Encoder heads | 4 | |
| Predictor embed dim | 32 | narrower than encoder |
| Predictor depth | 2 | shallower than encoder |
| Total params | ~1.1M | vs 632M for ViT-H |
| EMA tau | 0.996 → 0.9999 | cosine schedule |
Context Encoder (student):
- Processes only visible patches (context, not target positions)
- Gets gradients via backprop on every step
- The network that actually learns representations
Target Encoder (teacher):
- Same ViT architecture as context encoder
- Processes all patches
- Zero gradient updates — weights moved only by EMA:
  θ̄ ← τ · θ̄ + (1 − τ) · θ, with τ ramping 0.996 → 0.9999
- Acts as a slowly-moving "oracle" the predictor must match
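The EMA rule can be sketched over flat parameter slices. This is a minimal illustration: the repo's `src/train/ema.rs` operates on burn module parameters, and `ema_update` is a hypothetical name.

```rust
/// EMA teacher update: θ̄ ← τ·θ̄ + (1 − τ)·θ.
/// The teacher drifts slowly toward the student; no gradients are involved.
pub fn ema_update(teacher: &mut [f32], student: &[f32], tau: f32) {
    assert_eq!(teacher.len(), student.len());
    for (t, s) in teacher.iter_mut().zip(student.iter()) {
        *t = tau * *t + (1.0 - tau) * s;
    }
}
```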
Predictor:
- Lightweight transformer (half the width, half the depth)
- Input: context embeddings concatenated with mask tokens at target positions
- Mask token = shared learnable vector + positional embedding of the target patch
- This tells the predictor where to predict, not what
- Output: predicted embedding at each target position
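Assembling one mask token can be sketched like this, using plain `Vec<f32>` for illustration. The real implementation in `src/model/jepa.rs` uses burn tensors; the names here are assumptions.

```rust
/// Mask token for one target position: the shared learnable vector plus the
/// positional embedding of that patch. It encodes WHERE to predict, not WHAT.
pub fn mask_token_at(shared: &[f32], pos_embed: &[Vec<f32>], patch_idx: usize) -> Vec<f32> {
    shared
        .iter()
        .zip(pos_embed[patch_idx].iter())
        .map(|(m, p)| m + p)
        .collect()
}
```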
Self-supervised methods face representation collapse — the model learns to output the same constant embedding for every input, trivially minimizing any similarity loss.
JEPA's defense:
Suppose context_encoder collapses: f(x) = constant for all x.
Then predictor_output = constant.
But target_encoder(x) = EMA of {recent θ checkpoints applied to x}
≠ constant (because x varies and EMA ≠ current θ)
So MSE(constant, target_encoder(x)) stays high for varied x.
The only way to minimize loss: produce varied, informative embeddings.
The EMA creates a "moving target" that the student must genuinely chase. Too slow (τ → 1): the target is stable but never absorbs the student's progress. Too fast (τ → 0): the target tracks the student so closely that collapse becomes possible again. The 0.996 → 0.9999 schedule is empirically stable.
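A cosine ramp from 0.996 to 0.9999 can be sketched as follows. This is illustrative; the exact curve in `src/train/ema.rs` may differ.

```rust
/// Cosine schedule for the EMA coefficient: starts at `tau_start`,
/// ends at `tau_end`, changing fastest mid-training.
pub fn tau_at(step: usize, total_steps: usize, tau_start: f32, tau_end: f32) -> f32 {
    let progress = step as f32 / total_steps as f32; // 0.0 at start, 1.0 at end
    let ramp = 0.5 * (1.0 - (std::f32::consts::PI * progress).cos());
    tau_start + (tau_end - tau_start) * ramp
}
```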
4×4 patch grid example:
┌────────┐
│C C C C │ C = context (visible to context encoder)
│C C T C │ T = target (predicted by predictor,
│C T T C │ grounded by target encoder)
│C C C C │
└────────┘
Scale: target block = 15–40% of patches
Aspect ratio: 0.75–1.5
Min context: 8 patches always retained
Large rectangular blocks force semantic reasoning, not local texture interpolation. If you could infer T from adjacent C patches trivially, the predictor wouldn't need to learn anything about global structure.
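The deterministic core of block sampling can be sketched like this. The repo's `src/train/masking.rs` additionally draws the block's scale and aspect ratio at random within the ranges above; `sample_block` is an illustrative name.

```rust
/// Mark a block_h × block_w rectangle of a grid×grid patch grid as target.
/// Returns a flat mask: true = target (T), false = context (C).
pub fn sample_block(grid: usize, top: usize, left: usize, block_h: usize, block_w: usize) -> Vec<bool> {
    let mut mask = vec![false; grid * grid];
    for r in top..(top + block_h).min(grid) {
        for c in left..(left + block_w).min(grid) {
            mask[r * grid + c] = true; // this patch is hidden from the context encoder
        }
    }
    mask
}
```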
L = (1/M) Σᵢ ||ŝᵢ - sᵢ||²
where:
ŝᵢ = predictor output at target patch i (predicted)
sᵢ = target_encoder output at target patch i (ground truth, no grad)
M = total target patches in batch
MSE in embedding space. No reconstruction, no negatives, no augmentation. Just: "predict what the oracle would output here."
Also implemented: L1 loss (V-JEPA 2 style) and cosine similarity loss.
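The three losses can be sketched over flat embedding slices. Illustrative only: `src/train/loss.rs` computes them on burn tensors.

```rust
/// Mean squared error between predicted and target embeddings.
pub fn mse_loss(pred: &[f32], target: &[f32]) -> f32 {
    let n = pred.len() as f32;
    pred.iter().zip(target).map(|(p, t)| (p - t).powi(2)).sum::<f32>() / n
}

/// L1 (mean absolute error), V-JEPA 2 style.
pub fn l1_loss(pred: &[f32], target: &[f32]) -> f32 {
    let n = pred.len() as f32;
    pred.iter().zip(target).map(|(p, t)| (p - t).abs()).sum::<f32>() / n
}

/// Cosine loss: 0 when embeddings point the same way, up to 2 when opposed.
pub fn cosine_loss(pred: &[f32], target: &[f32]) -> f32 {
    let dot: f32 = pred.iter().zip(target).map(|(p, t)| p * t).sum();
    let np = pred.iter().map(|p| p * p).sum::<f32>().sqrt();
    let nt = target.iter().map(|t| t * t).sum::<f32>().sqrt();
    1.0 - dot / (np * nt + 1e-8)
}
```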
src/
├── main.rs CLI: train, mask-demo, info, sanity
├── model/
│ ├── patch_embed.rs Image [B,C,H,W] → patches [B,N,E]
│ ├── attention.rs Multi-head self-attention (Q,K,V projections)
│ ├── transformer.rs Pre-norm transformer block + stack
│ ├── vit.rs Vision Transformer (full & masked forward)
│ └── jepa.rs JepaModel: context_encoder + target_encoder + predictor
├── data/
│ ├── synthetic.rs 6 types of synthetic images, zero download
│ └── batch.rs Vec<f32> → Tensor<B,4> utilities
├── train/
│ ├── masking.rs Block mask sampler (configurable scale/aspect)
│ ├── loss.rs MSE, L1, cosine losses in latent space
│ ├── ema.rs EMA tau scheduler (cosine, 0.996 → 0.9999)
│ └── trainer.rs Training loop structure + demo runner
└── demo/
└── visualize.rs ASCII loss curves, mask grids, heatmaps
research/
├── 01-jepa-architecture.md Full architecture breakdown + diagrams
├── 02-implementation-decisions.md Why burn, why nano scale, data flow
└── 03-training-dynamics.md Math of collapse prevention, EMA dynamics
```bash
# Quick demo — 20 forward steps, shows full mechanism, ~30s on CPU
cargo run --release -- train --demo

# Full forward-pass run (50 steps, 1000 samples)
cargo run --release -- train --epochs 50 --samples 1000

# Block masking visualization
cargo run --release -- mask-demo

# Architecture + parameter details
cargo run --release -- info

# Quick forward pass sanity check
cargo run --release -- sanity

# Full test suite (21 tests)
cargo test
```

Stack:
- Pure Rust, burn 0.16 ML framework
- `NdArray` backend — no CUDA, no libtorch, no C++ deps, runs everywhere
- 21/21 unit tests pass
Extending to full training with backprop:

```rust
// Swap NdArray for Autodiff<NdArray> to get gradients:
type TrainBackend = burn::backend::Autodiff<NdArray<f32>>;

// Then use burn's optimizer:
let optimizer = AdamConfig::new().with_weight_decay(...).init();
let grads = loss.backward();
let grads = GradientsParams::from_grads(grads, &model);
model = optimizer.step(lr, model, grads);

// EMA update after each step:
// target_encoder ← tau * target_encoder + (1-tau) * context_encoder
```

Swapping to GPU:

```toml
# In Cargo.toml, change backend:
burn = { version = "0.16", features = ["wgpu", "autodiff"] }
```

```rust
type Backend = burn::backend::Autodiff<burn::backend::Wgpu>;
```

Reference: Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture, Assran et al., CVPR 2023, Meta AI Research