V-JEPA 2 (Video Joint-Embedding Predictive Architecture 2) is a video world model that learns spatiotemporal dynamics by predicting future frame representations in latent space. It learns "how the world changes over time" from video data alone, without requiring action labels.
Temporal dynamics learning without actions:
- Predicts future frame representations from past context
- Learns motion, object dynamics, and temporal patterns
- Uses spatiotemporal transformer architecture
- EMA target encoder for stable training
- Operates in representation space, not pixels
V-JEPA 2 learns temporal dynamics:
z_t+1 = f(z_t, z_t-1, ..., z_t-k)
Where:
- z_t: representation at time t
- f: learned dynamics function (the predictor), which predicts future representations from past context
This is exactly what a world model does, but without explicit actions!
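The dynamics function f can be pictured as a small network over a window of past latents. A toy sketch (the `LatentDynamics` module and its sizes are illustrative, not the actual V-JEPA 2 predictor):

```python
import torch
import torch.nn as nn

class LatentDynamics(nn.Module):
    """Toy action-free dynamics model: predicts z_{t+1} from the last k latents."""
    def __init__(self, dim=32, k=4):
        super().__init__()
        self.k = k
        self.f = nn.Sequential(
            nn.Linear(dim * k, 128), nn.GELU(), nn.Linear(128, dim)
        )

    def forward(self, z_history):
        # z_history: (B, k, D) -> flatten the context window, predict next latent
        return self.f(z_history.flatten(1))  # (B, D)

model = LatentDynamics(dim=32, k=4)
z = torch.randn(2, 4, 32)   # batch of 2, context of k=4 past latents
z_next = model(z)           # (2, 32): one predicted future latent per sequence
```

The real predictor is a transformer over spatiotemporal tokens, but the interface is the same: past latents in, future latents out.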
V-JEPA 2 extends I-JEPA to the temporal domain:
# Context encoder (trainable)
z_context = encoder(frames[:t])
# Target encoder (EMA, frozen during forward)
z_target = target_encoder(frames[t+1:])
# Predictor (trainable)
z_pred = predictor(z_context, target_positions_in_time)
# Loss in representation space
loss = ||z_pred - z_target||²
Key differences from I-JEPA:
- Temporal masking instead of spatial
- Spatiotemporal attention instead of spatial-only
- Predicts future, not just masked regions
- Models change over time
V-JEPA 2 uses factorized space-time attention:
Spatial Attention (within frames):
# Attend to different spatial locations in same frame
Q, K, V = frame_patches
spatial_attn = Attention(Q, K, V) # Within frame
Temporal Attention (across frames):
# Attend to same spatial location across time
Q, K, V = temporal_sequence
temporal_attn = Attention(Q, K, V) # Across frames
This factorization:
- More efficient than full 3D attention
- Learns spatial and temporal structure separately
- Better inductive bias for videos
- Scales to longer sequences
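The efficiency claim can be quantified: attention cost grows quadratically in sequence length, so factorization replaces one joint (T·N)² attention map with T spatial maps of size N² plus N temporal maps of size T². A back-of-envelope check, assuming 8 temporal patch groups and 14×14 spatial patches:

```python
# Attention cost scales with (sequence length)^2 per attention map.
T, N = 8, 196                       # temporal groups, 14*14 spatial patches

full_3d = (T * N) ** 2              # one joint map over all T*N tokens
factorized = T * N**2 + N * T**2    # T spatial maps + N temporal maps

print(full_3d, factorized, full_3d / factorized)
```

With these sizes the factorized scheme does roughly 7–8× less attention work per block, and the gap widens with longer clips.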
Given a video sequence V = {f_1, f_2, ..., f_T}:
1. Sample context and target frames:
- Context frames: {f_1, ..., f_t} (past)
- Target frames: {f_t+1, ..., f_T} (future)
2. Encode context: z_c = Encoder(frames[:t])
3. Predict future representations: ẑ_f = Predictor(z_c, future_time_steps)
4. Encode actual future: z_f = TargetEncoder(frames[t+1:])
5. Minimize prediction error: L = ||ẑ_f - z_f||²
V-JEPA 2 uses future-biased masking:
Context Frames:
- Keep early frames (e.g., frames 1-8 out of 16)
- Provides past context
- ~50% of sequence
Target Frames:
- Predict later frames (e.g., frames 9-16)
- Forces temporal prediction
- ~50% of sequence
Alternative strategies:
- Random frame dropout
- Periodic frame sampling
- Multi-rate prediction (predict 1, 2, 4 steps ahead)
Same as I-JEPA but for video:
θ_target ← τ · θ_target + (1 - τ) · θ_encoder
Where:
- τ = 0.996: EMA decay
- Prevents temporal collapse
- Stable targets for video sequences
- Smoother dynamics learning
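A common implementation sketch of this update (a standard EMA loop over parameters; the function name is illustrative):

```python
import torch
import torch.nn as nn

@torch.no_grad()
def update_ema(target, online, tau=0.996):
    """theta_target <- tau * theta_target + (1 - tau) * theta_online."""
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(tau).add_(p_o, alpha=1 - tau)

# Usage with toy modules standing in for the two encoders
online = nn.Linear(4, 4)
target = nn.Linear(4, 4)
target.load_state_dict(online.state_dict())  # start from identical weights
update_ema(target, online, tau=0.996)
```

Because tau is close to 1, the target encoder moves slowly, giving the predictor a stable regression target across training steps.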
Think of V-JEPA 2 as learning "physics intuition" from videos:
1. Watch the Past:
- Observe first half of video
- Build understanding of scene dynamics
- Example: ball is moving left, person is walking
2. Predict the Future:
- Predict what happens next (in feature space)
- Not pixel-by-pixel, but semantic features
- Example: "ball continues left", "person keeps walking"
3. Learn Dynamics:
- Understand motion patterns
- Learn object trajectories
- Model temporal coherence
4. No Actions Needed:
- Learns from passive observation
- Discovers natural dynamics
- No need for action labels
Key Insight: By predicting the future in representation space, V-JEPA 2 learns meaningful temporal dynamics without pixel-level detail or action labels.
V-JEPA 2 models change over time:
- Predicts future frames from past frames
- Learns motion patterns and object dynamics
- Captures temporal dependencies
- Learns temporal ordering (a proxy for causality, not true causal inference)
Like modern world models, V-JEPA 2 operates in latent space:
- More efficient than pixel prediction
- More semantic (learns "what changes" not "pixel values")
- Better generalization
- Focuses on important features
V-JEPA 2 learns dynamics from observation alone:
- No action labels required
- No reward signals needed
- Learns natural video statistics
- Scalable to internet-scale data
This enables:
- Training on vast internet video data
- Zero-shot transfer to new domains
- Robotics applications with passive observation
Learns both spatial and temporal structure:
- Spatial: objects, scenes, layouts
- Temporal: motion, change, dynamics
- Spatiotemporal: how objects move through space
Video Encoder (Spatiotemporal Transformer):
- Input: Video patches (B, T, H, W, C)
- Patch embedding: 16x16 spatial, 2 temporal
- Factorized space-time attention
- Layers: 24
- Hidden dim: 1024
- Heads: 16
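With these patch sizes, the token count for a 16-frame 224×224 clip works out as follows:

```python
frames, H, W = 16, 224, 224
spatial_patch, temporal_patch = 16, 2

tokens_per_time_group = (H // spatial_patch) * (W // spatial_patch)  # 14*14
temporal_groups = frames // temporal_patch                           # 16/2
total_tokens = tokens_per_time_group * temporal_groups

print(tokens_per_time_group, temporal_groups, total_tokens)
```

So the encoder operates on 1,568 spatiotemporal tokens per clip, which is what makes factorized attention worthwhile.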
Factorized Attention:
# Spatial attention block
x = spatial_attention(x) # Attend within frames
# Temporal attention block
x = temporal_attention(x) # Attend across frames
# Alternate spatial and temporal blocks
Predictor:
- Input: Context representations + future time positions
- Architecture: Lightweight transformer
- Layers: 8
- Hidden dim: 512
- Predicts future frame representations
Target Encoder (EMA):
- Same architecture as context encoder
- Updated via EMA (τ = 0.996)
- No gradients during forward pass
# Pseudo-code for V-JEPA 2 training
# Initialize
context_encoder = SpatiotemporalTransformer(config)
target_encoder = SpatiotemporalTransformer(config)
predictor = TemporalPredictor(config)
# Copy initial weights
target_encoder.load_state_dict(context_encoder.state_dict())
target_encoder.requires_grad_(False)
for batch in dataloader:
    videos = batch  # (B, T, C, H, W)
    # 1. Sample temporal mask
    context_frames, target_frames = sample_temporal_mask(videos)
    # context_frames: (B, T_c, C, H, W)
    # target_frames:  (B, T_t, C, H, W)
    # 2. Encode context
    z_context = context_encoder(context_frames)  # (B, T_c, D)
    # 3. Predict target representations
    z_pred = predictor(z_context, target_time_steps)  # (B, T_t, D)
    # 4. Encode targets (no gradient)
    with torch.no_grad():
        z_target = target_encoder(target_frames)  # (B, T_t, D)
    # 5. Compute loss
    loss = F.mse_loss(z_pred, z_target)
    # 6. Update context encoder and predictor
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # 7. Update target encoder (EMA)
    update_ema(target_encoder, context_encoder, tau=0.996)
V-JEPA 2 implementation reference:
# Conceptual API
from nexus.models.ssl import VJEPAModel
config = {
"num_frames": 16,
"frame_size": 224,
"patch_size": 16,
"temporal_patch_size": 2,
"encoder_dim": 1024,
"encoder_depth": 24,
"encoder_heads": 16,
"predictor_dim": 512,
"predictor_depth": 8,
"ema_decay": 0.996,
}
model = VJEPAModel(config)
# Training
for videos in dataloader:
    loss, metrics = model(videos)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
# Extract features for downstream tasks
video = load_video("robot_demo.mp4")
features = model.encode(video)  # (T, D)
Future Prediction:
def future_prediction_mask(num_frames=16, context_ratio=0.5):
    """Keep first half, predict second half."""
    split_point = int(num_frames * context_ratio)
    context_frames = list(range(split_point))
    target_frames = list(range(split_point, num_frames))
    return context_frames, target_frames
Random Frame Dropout:
def random_frame_mask(num_frames=16, mask_ratio=0.5):
    """Randomly mask frames."""
    num_masked = int(num_frames * mask_ratio)
    masked_indices = random.sample(range(num_frames), num_masked)
    context_frames = [i for i in range(num_frames) if i not in masked_indices]
    target_frames = masked_indices
    return context_frames, target_frames
Multi-Rate Prediction:
def multi_rate_mask(num_frames=16):
    """Predict multiple future horizons."""
    context = list(range(8))  # First 8 frames
    # Predict frames 8, 10, 12, 14 (0-indexed, skipping alternate frames)
    targets = [8, 10, 12, 14]
    return context, targets
# Assumes einops.rearrange plus Attention and MLP modules are in scope
class SpatiotemporalBlock(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.spatial_attn = Attention(dim, num_heads)
        self.temporal_attn = Attention(dim, num_heads)
        self.mlp = MLP(dim)

    def forward(self, x):
        # x: (B, T, H*W, D)
        B, T, N, D = x.shape
        # Spatial attention (within frames)
        x = rearrange(x, 'b t n d -> (b t) n d')
        x = x + self.spatial_attn(x)
        x = rearrange(x, '(b t) n d -> b t n d', b=B, t=T)
        # Temporal attention (across frames)
        x = rearrange(x, 'b t n d -> (b n) t d')
        x = x + self.temporal_attn(x)
        x = rearrange(x, '(b n) t d -> b t n d', b=B, n=N)
        # MLP
        x = x + self.mlp(x)
        return x
# Accumulate gradients for longer sequences
accumulation_steps = 4
for i, batch in enumerate(dataloader):
    loss, _ = model(batch)
    loss = loss / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for videos in dataloader:
    with autocast():
        loss, metrics = model(videos)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
def get_num_frames(epoch, max_frames=16):
    """Gradually increase sequence length."""
    if epoch < 10:
        return 8  # Start with 8 frames
    elif epoch < 20:
        return 12  # Increase to 12
    else:
        return max_frames  # Full 16 frames
| Method | Kinetics-400 | Something-Something-v2 |
|---|---|---|
| Supervised ViT | 81.2% | 68.4% |
| VideoMAE | 83.1% | 71.2% |
| V-JEPA 2 | 85.3% | 74.8% |
Strong performance on temporal reasoning!
Zero-shot transfer to robot manipulation:
| Pre-training | Success Rate (Pick-Place) |
|---|---|
| Random init | 12% |
| ImageNet | 45% |
| I-JEPA | 58% |
| V-JEPA 2 | 72% |
Temporal understanding helps robot control!
Train on 10% of Kinetics:
| Method | 10% Kinetics Accuracy |
|---|---|
| Supervised | 42.3% |
| Contrastive | 58.7% |
| VideoMAE | 65.2% |
| V-JEPA 2 | 71.5% |
Excellent data efficiency!
Future frame prediction accuracy:
| Horizon | V-JEPA 2 MSE | VideoMAE MSE |
|---|---|---|
| 1 frame | 0.012 | 0.018 |
| 2 frames | 0.024 | 0.041 |
| 4 frames | 0.051 | 0.093 |
| 8 frames | 0.098 | 0.187 |
Better long-term prediction!
| Method | GPU Hours (Kinetics) |
|---|---|
| Supervised | 3072 |
| VideoMAE | 4096 |
| V-JEPA 2 | 3200 |
Efficient training on videos.
| Aspect | V-JEPA 2 | DreamerV3 |
|---|---|---|
| Actions | ❌ Action-free | ✅ Action-conditioned |
| Rewards | ❌ No rewards | ✅ Reward prediction |
| Use Case | Representation learning | RL with planning |
| Training Data | Any videos | RL episodes |
| Output | Representations | Pixels + rewards |
V-JEPA 2 = pre-training for DreamerV3!
| Aspect | V-JEPA 2 | I-JEPA |
|---|---|---|
| Temporal | ✅ Video | ❌ Single image |
| Dynamics | ✅ Models change | ❌ Static |
| Use Case | Temporal modeling | Spatial modeling |
| Architecture | Spatiotemporal | Spatial-only |
V-JEPA 2 extends I-JEPA to time!
| Aspect | V-JEPA 2 | Genie |
|---|---|---|
| Actions | ❌ | ✅ Latent actions |
| Output | Representations | Pixels (playable) |
| Goal | Learn dynamics | Generate worlds |
| Interactive | ❌ | ✅ |
V-JEPA 2 = representation learning, Genie = world generation.
V-JEPA 2 excels at:
- Action recognition
- Video classification
- Temporal reasoning
- Event detection
Example:
# Pre-train V-JEPA on videos
vjepa = VJEPAModel(config)
vjepa.pretrain(youtube_videos)
# Fine-tune for action recognition
classifier = nn.Linear(vjepa.encoder_dim, num_classes)
for videos, labels in dataloader:
    features = vjepa.encode(videos)
    logits = classifier(features.mean(dim=1))
    loss = F.cross_entropy(logits, labels)
Zero-shot transfer to robot control:
# Pre-train V-JEPA on human videos
vjepa = VJEPAModel(config)
vjepa.pretrain(youtube_videos)
# Use for robot control
robot_obs = camera.get_video_sequence()
robot_features = vjepa.encode(robot_obs)
action = policy(robot_features[-1])  # Use latest frame features
Predict what happens next:
def predict_future(vjepa, past_frames, num_future_frames=4):
    """Predict future frame representations autoregressively."""
    # Encode past
    z_past = vjepa.encode(past_frames)
    # Predict future, feeding each prediction back into the context
    z_future = []
    for t in range(num_future_frames):
        z_t = vjepa.predict_next(z_past)
        z_future.append(z_t)
        z_past = torch.cat([z_past[1:], z_t.unsqueeze(0)], dim=0)
    return z_future
Use V-JEPA 2 as foundation for model-based RL:
# Pre-train on videos
vjepa.pretrain(video_data)
# Add action conditioning
world_model = ActionConditionedVJEPA(vjepa)
world_model.add_action_input()
# Train RL agent
agent.learn_with_world_model(world_model)
Detect unusual events:
def detect_anomaly(vjepa, video, threshold=0.1):
    """Detect anomalies via prediction error."""
    past = video[:8]
    future = video[8:]
    # Encode and predict
    z_past = vjepa.encode(past)
    z_future_pred = vjepa.predict_future(z_past, len(future))
    z_future_actual = vjepa.encode(future)
    # Compute prediction error
    error = F.mse_loss(z_future_pred, z_future_actual)
    # High error = anomaly
    is_anomaly = error > threshold
    return is_anomaly, error
Problem: Model predicts static features, ignores motion
Symptoms:
- Same prediction for all future frames
- No temporal variation
- Poor video classification
Solutions:
# Temporal diversity loss
def temporal_diversity_loss(features):
    # features: (B, T, D)
    variance = features.var(dim=1).mean()
    return -variance  # Maximize temporal variance

loss = prediction_loss + 0.1 * temporal_diversity_loss(predicted_features)
# Temporal contrastive loss
# Positive: same video, negative: different videos
Problem: OOM for long videos
Symptoms:
- CUDA out of memory
- Can't process full videos
Solutions:
# Gradient checkpointing
from torch.utils.checkpoint import checkpoint
def forward_with_checkpointing(x):
    return checkpoint(model, x)
# Process in chunks
def encode_long_video(video, chunk_size=8):
    features = []
    for i in range(0, len(video), chunk_size):
        chunk = video[i:i+chunk_size]
        with torch.no_grad():
            feat = model.encode(chunk)
        features.append(feat)
    return torch.cat(features, dim=0)
Problem: Encoder doesn't align frames temporally
Symptoms:
- Poor future prediction
- Temporal order confusion
Solutions:
# Strong temporal positional encoding
pos_embed = nn.Parameter(torch.randn(1, max_frames, dim))
x = x + pos_embed[:, :T]
# Relative positional encoding
# Learn offsets between frames, not absolute positions
Problem: Model learns static appearance, not dynamics
Symptoms:
- Good on static videos, poor on dynamic
- Ignores motion
Solutions:
# Data augmentation: temporal jittering
def temporal_jitter(video, max_jitter=2):
    """Randomly shift frame indices."""
    indices = torch.arange(len(video))
    jitter = torch.randint(-max_jitter, max_jitter + 1, (len(video),))
    indices = torch.clamp(indices + jitter, 0, len(video) - 1)
    return video[indices]
# Motion-based sampling
# Prefer videos with more motion during training
Problem: Videos have different speeds/scales
Symptoms:
- Poor generalization across datasets
- Speed-dependent representations
Solutions:
# Temporal rescaling
def temporal_rescale(video, target_frames=16):
    """Resample a (B, C, T, H, W) video to target_frames along time."""
    return F.interpolate(
        video, size=(target_frames, *video.shape[-2:]), mode='trilinear'
    )
# Multi-scale temporal modeling
# Process at different frame rates: 1fps, 2fps, 4fps
Train on internet videos:
- Millions of hours of data
- Diverse dynamics
- Rich temporal patterns
- YouTube, movies, robot demos
Pre-trained representations transfer to:
- New domains (different visual styles)
- New tasks (classification, control)
- New modalities (sim-to-real)
- Different speeds (slow-mo to time-lapse)
Operates in compact latent space:
- 1024-dim vectors (not 224×224×3 pixels)
- Fast dynamics modeling
- Efficient planning
- Scalable to long videos
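The compression factor behind this efficiency is easy to check:

```python
pixel_dim = 224 * 224 * 3   # raw per-frame dimensionality
latent_dim = 1024           # V-JEPA 2 representation size

print(pixel_dim, pixel_dim / latent_dim)
```

Each latent vector is roughly 147× smaller than its raw frame, which is why predicting and planning in latent space is so much cheaper than in pixels.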
Unlike DreamerV3, V-JEPA 2 doesn't need:
- Explicit action labels
- Reward signals
- RL environment interaction
Can learn from passive observation!
Cannot model:
- Action-conditioned dynamics: s_t+1 = f(s_t, a_t)
- Policy learning
- Interactive control
Solution: Add action conditioning (future work) or use DreamerV3.
Cannot:
- Predict task-relevant outcomes
- Train RL policies directly
- Optimize for goals
Solution: Add reward prediction head for RL applications.
Predicts single future (no uncertainty):
- Can't model stochastic environments
- No distribution over futures
Solution: Add stochastic latent variables (like DreamerV3's RSSM).
Doesn't generate pixels:
- Can't visualize predictions
- Need decoder for interpretability
Solution: Add optional decoder for visualization (like VideoMAE).
@article{bardes2024vjepa,
title={Revisiting Feature Prediction for Learning Visual Representations from Video},
author={Bardes, Adrien and Garrido, Quentin and Ponce, Jean and Rabbat, Michael and LeCun, Yann and Assran, Mahmoud and Ballas, Nicolas},
journal={arXiv preprint arXiv:2404.08471},
year={2024}
}
Official Code: https://github.com/facebookresearch/jepa
Paper: https://arxiv.org/abs/2404.08471
For complete V-JEPA 2 documentation, see the full reference, which includes:
- Detailed architecture
- Training procedures
- Code walkthroughs
- Optimization tricks
- Experimental results
V-JEPA 2 as a world model:
- ✅ Models temporal dynamics
- ✅ Learns from video observations
- ✅ Scalable pre-training
- ✅ Zero-shot transfer
- ✅ Efficient latent space
- ✅ Spatiotemporal understanding
- ❌ No action conditioning
- ❌ No reward prediction
- ❌ Deterministic predictions
Use V-JEPA 2 when:
- Learning from passive video observation
- Robotics with observation-only pre-training
- Video understanding tasks
- Transfer learning for dynamics
- Don't have action labels
- Need temporal representations
Upgrade to DreamerV3 when:
- Need action-conditioned dynamics
- Training RL agents
- Interactive environments
- Reward-driven learning
- Stochastic environments
V-JEPA 2 as Foundation:
- Excellent pre-training for visual RL
- Strong temporal representations
- Can add action/reward heads on top
- Combines SSL and world modeling
- Bridge between vision and control
Key Takeaways:
- Temporal prediction in representation space
- Factorized spatiotemporal attention
- Future-biased masking for dynamics
- EMA prevents temporal collapse
- Strong transfer to downstream tasks