LLaVA-NeXT represents the next generation of the groundbreaking LLaVA (Large Language and Vision Assistant) architecture, addressing critical limitations of the original model while maintaining its open-source accessibility. Building on LLaVA's success as one of the first practical vision-language instruction-following models, LLaVA-NeXT introduces substantial improvements in visual understanding, particularly for high-resolution images and complex spatial reasoning tasks.
Key Achievement: One of the first open-source multimodal models to approach GPT-4V performance on several benchmarks while remaining fully transparent and reproducible.
Evolution Timeline:
- LLaVA 1.0 (April 2023): Pioneered visual instruction tuning with simple projector
- LLaVA 1.5 (October 2023): Added MLP projector, improved data quality
- LLaVA-NeXT (January 2024): Dynamic high-resolution, enhanced spatial reasoning
- LLaVA-OneVision (August 2024): Unified image and video understanding
Why LLaVA-NeXT Matters:
- Open Science: Unlike proprietary models (GPT-4V, Gemini Vision), LLaVA-NeXT provides full access to:
  - Model architecture and weights
  - Training data and recipes
  - Evaluation methodologies
- Dynamic Resolution: Processes images with up to 4× more pixels (336×336 → 672×672 or beyond) while maintaining efficiency through adaptive patch allocation
- Enhanced Spatial Reasoning: Significantly improved performance on:
  - Text-rich images and OCR tasks
  - Charts and diagrams
  - Fine-grained visual details
  - Multi-image reasoning
- Practical Performance: Achieves competitive results with much larger proprietary models, making advanced multimodal AI accessible to researchers and developers
Performance Highlights:
- MMBench: 70.5% accuracy (vs. 65.6% for LLaVA-1.5)
- TextVQA: 67.1% accuracy (up from 58.2% for LLaVA-1.5)
- ChartQA: 62.3% accuracy (substantial improvement on visual reasoning)
LLaVA-NeXT builds on the visual instruction tuning paradigm introduced by the original LLaVA, which demonstrated that language models can be effectively adapted to multimodal tasks through instruction following.
Core Hypothesis: A language model pretrained on text can be efficiently extended to vision-language tasks by:
- Aligning visual features to the language model's input space
- Fine-tuning on high-quality instruction-following data
- Preserving the language model's reasoning capabilities
The fundamental innovation in LLaVA-NeXT is its any-resolution processing capability, which addresses a critical limitation of fixed-resolution approaches.
Problem with Fixed Resolution:
Original image: 1920×1080 (Full HD)
Downsampled to: 336×336
Information loss: ~95% of pixels discarded
Result: Poor text recognition, missed details
LLaVA-NeXT Solution:
1. Analyze aspect ratio and content requirements
2. Divide image into NxM grid of patches
3. Process each patch at base resolution (336×336)
4. Combine global view with detailed patches
Advantages:
- Preserves fine-grained details for OCR and text-rich images
- Adapts to image aspect ratio (avoids distortion)
- Balances computational cost with visual fidelity
- Scales to very high resolutions when needed
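As a rough, illustrative calculation (not part of the model code), the snippet below compares how many native pixels survive fixed 336×336 downsampling versus the 3×2-grid-plus-global-view configuration from the Full HD example above:

# Back-of-the-envelope pixel-retention comparison (illustrative only)
H, W = 1080, 1920        # Full HD input from the example above
base = 336               # base patch resolution
num_detail_patches = 6   # a 3×2 grid of detail patches

fixed_pixels = base * base                              # single downsampled view
anyres_pixels = (1 + num_detail_patches) * base * base  # global view + detail patches

total = H * W
print(f"fixed-resolution keeps {fixed_pixels / total:.1%} of native pixels")   # ~5.4%
print(f"any-resolution keeps   {anyres_pixels / total:.1%} of native pixels")  # ~38.1%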
LLaVA-NeXT employs a carefully designed three-stage training curriculum:
Stage 1: Visual-Language Alignment (Pretraining)
- Data: Image-caption pairs (e.g., CC3M, LAION)
- Objective: Align vision encoder outputs with language model input space
- Frozen: Vision encoder, language model
- Trainable: Only the projection layer
- Duration: ~1 epoch, 500K-1M samples
Stage 2: Visual Instruction Tuning
- Data: LLaVA-Instruct-158K (GPT-4 generated instructions)
- Objective: Teach instruction following and visual reasoning
- Frozen: Vision encoder
- Trainable: Projection layer, language model
- Duration: 1 epoch
Stage 3: High-Resolution Fine-Tuning
- Data: Mix of high-resolution datasets (TextVQA, DocVQA, ChartQA)
- Objective: Enhance detail perception and complex reasoning
- Trainable: Full model with any-resolution processing
- Duration: 1 epoch on curated high-quality data
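As a minimal sketch of how this curriculum can translate into parameter freezing (submodule names like vision_encoder, projector, and language_model are assumptions about the model's attributes, not the exact implementation):

# Illustrative per-stage freezing; submodule names are assumptions
def configure_stage(model, stage: int):
    # Start with everything frozen, then unfreeze according to the curriculum
    for p in model.parameters():
        p.requires_grad = False

    if stage == 1:    # alignment: only the projector learns
        trainable = [model.projector]
    elif stage == 2:  # instruction tuning: projector + language model
        trainable = [model.projector, model.language_model]
    else:             # high-resolution fine-tuning: full model
        trainable = [model]

    for module in trainable:
        for p in module.parameters():
            p.requires_grad = True
    return model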
1. Dynamic Patch Allocation
Instead of fixed grid splitting, LLaVA-NeXT dynamically determines the optimal patch configuration:
Allocation Strategy:
- Input: Image dimensions (H, W), max_patches
- Compute aspect ratio: r = W / H
- Find grid (n, m) that minimizes:
* Distortion: |n/m - r|
* Token count: n × m ≤ max_patches
* Coverage: Maximize utilization of max_patches
2. Improved Visual-Language Connector
Evolution from simple linear projection to sophisticated multi-layer projector:
LLaVA 1.0: Linear projection
visual_dim → language_dim
LLaVA 1.5: MLP projector
visual_dim → hidden_dim → language_dim
+ GELU activation
LLaVA-NeXT: Any-resolution projector
[global_view, patch_1, ..., patch_n] →
spatial_pooling → MLP → language_space
3. Position-Aware Encoding
Preserves spatial relationships in multi-patch processing:
- Each patch retains its grid position (i, j)
- Learnable 2D positional embeddings
- Enables spatial reasoning across patches
Primary Loss: Autoregressive Language Modeling
L_language = -Σ log P(y_t | x_visual, y_<t)
where:
y_t = target token at position t
x_visual = visual features
y_<t = previous tokens
Visual-Language Alignment Loss (Stage 1 only):
L_align = MSE(proj(v), language_embed(caption))
Encourages visual features to align with corresponding text embeddings
Instruction-Following Loss (Stages 2-3):
L_instruct = -Σ log P(answer | instruction, image)
Computed only on answer tokens, not instruction tokens
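A minimal sketch of how such an answer-only loss is typically computed in PyTorch, assuming labels use -100 to mark visual and instruction positions that should be ignored:

import torch
import torch.nn.functional as F

def instruction_loss(logits, labels):
    """Cross-entropy over answer tokens only.

    logits: [batch, seq_len, vocab_size] from the language model
    labels: [batch, seq_len] with -100 at visual/instruction positions
    """
    # Shift so that position t predicts token t+1
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,  # instruction/visual tokens do not contribute to the loss
    )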
LLaVA-NeXT's success relies heavily on curated high-quality training data:
Data Sources:
- Captioning: CC3M, SBU, LAION-400M (filtered)
- VQA: VQAv2, GQA, OKVQA, A-OKVQA
- Text-Rich: TextVQA, DocVQA, InfographicVQA
- Charts/Diagrams: ChartQA, PlotQA, FigureQA
- Instruction: LLaVA-Instruct (GPT-4 generated)
- Reasoning: VizWiz, Visual Genome
Data Filtering:
- Remove low-quality images (resolution, blur, artifacts)
- Filter out toxic or inappropriate content
- Balance dataset composition
- Deduplicate similar examples
1. Vision Encoding
Given input image
Step 1: Dynamic Resolution Processing
Determine grid configuration (n, m):
n, m = argmin_{(i,j): i×j ≤ max_patches} distortion(i, j, H/W)
Partition image:
I_patches = {I_{ij} | i ∈ [1,n], j ∈ [1,m]}
where each I_{ij} ∈ ℝ^{336 × 336 × 3}
Step 2: Patch Encoding
For each patch I_{ij}:
v_{ij} = ViT_encoder(I_{ij})
where v_{ij} ∈ ℝ^{N_p × d_v}
N_p = (336/patch_size)² = number of visual tokens per patch
d_v = vision encoder dimension (typically 1024)
Step 3: Global View Encoding
I_global = resize(I, 336, 336)
v_global = ViT_encoder(I_global)
Step 4: Feature Aggregation
V = concat([v_global, v_{11}, v_{12}, ..., v_{nm}])
V ∈ ℝ^{(1+n×m)×N_p × d_v}
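To make Steps 1-4 concrete, here is a small sketch (under simplifying assumptions) that resizes the image onto an n×m canvas, crops base-resolution tiles, prepends a global view, and runs everything through a stand-in vit_encoder:

import torch
import torch.nn.functional as F

def encode_any_resolution(image, vit_encoder, grid, base=336):
    """image: [3, H, W]; grid: (n_rows, m_cols); returns [1 + n*m, N_p, d_v]."""
    n, m = grid
    # Global view: the whole image resized to the base resolution
    global_view = F.interpolate(image.unsqueeze(0), size=(base, base),
                                mode='bilinear', align_corners=False)
    # Detail patches: resize onto an n×m canvas, then crop base-sized tiles
    canvas = F.interpolate(image.unsqueeze(0), size=(n * base, m * base),
                           mode='bilinear', align_corners=False)[0]
    tiles = [canvas[:, i * base:(i + 1) * base, j * base:(j + 1) * base]
             for i in range(n) for j in range(m)]
    views = torch.cat([global_view, torch.stack(tiles)], dim=0)  # [1 + n*m, 3, base, base]
    return vit_encoder(views)  # assumed to return [1 + n*m, N_p, d_v]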
2. Visual-Language Projection
Transform visual features to language model input space:
Step 1: Spatial pooling (optional)
V' = pool(V, strategy='adaptive')
Reduces sequence length while preserving information
Step 2: Multi-layer projection
H_0 = V'
H_l = GELU(W_l H_{l-1} + b_l) for l = 1, ..., L-1
Z_visual = W_L H_{L-1} + b_L
where Z_visual ∈ ℝ^{N_visual × d_lang}
d_lang = language model dimension (typically 4096)
3. Multi-Modal Fusion
Combine visual and text tokens:
Tokenize text instruction:
text = "Describe this image in detail."
T = tokenizer(text)
T_embed = embedding_layer(T) ∈ ℝ^{N_text × d_lang}
Interleave modalities:
X = [BOS, Z_visual, T_embed, EOS]
X ∈ ℝ^{(2 + N_visual + N_text) × d_lang}
Create attention mask:
M = [1, 1, ..., 1] (all ones, causal masking handled by LLM)
4. Language Model Generation
For t = 1 to max_length:
# Get hidden states
H_t = LLM(X, position_ids, attention_mask=M)
# Predict next token
logits_t = H_t[-1] @ W_vocab^T ∈ ℝ^{vocab_size}
# Sample or greedy decode
token_t = sample(logits_t, temperature, top_p)
# Append to sequence
X = concat([X, embedding(token_t)])
Training Loss (Instruction Tuning):
L_total = L_visual + L_text
where:
L_visual = -Σ_{t=1}^{T_visual} log P(y_t^visual | x_visual, y_{<t})
L_text = -Σ_{t=1}^{T_answer} log P(y_t^answer | x_visual, instruction, y_{<t})
Weighting Strategy:
L = α · L_visual + β · L_text
Typical values:
α = 0.0 (often ignored in later stages)
β = 1.0 (focus on answer quality)
Optimization Problem:
Given:
H, W = image dimensions
max_patches = computational budget
base_size = 336 (base patch resolution)
Find:
(n, m) = optimal grid configuration
Objective:
minimize: |aspect_ratio_grid - aspect_ratio_image| + λ · waste
where:
aspect_ratio_grid = m / n
aspect_ratio_image = W / H
waste = max_patches - n × m
λ = trade-off parameter (typically 0.1)
Constraints:
n, m ≥ 1
n × m ≤ max_patches
Efficient Algorithm:
def select_grid(H, W, max_patches=5):
    """Pick the (rows, cols) grid that best matches the image aspect ratio."""
    target_ratio = W / H
    best_config = (1, 1)
    best_score = float('inf')
    for n in range(1, max_patches + 1):           # rows
        for m in range(1, max_patches // n + 1):  # cols; guarantees n * m <= max_patches
            grid_ratio = m / n
            distortion = abs(grid_ratio - target_ratio)
            waste = (max_patches - n * m) / max_patches
            score = distortion + 0.1 * waste
            if score < best_score:
                best_score = score
                best_config = (n, m)
    return best_config

Each patch gets 2D positional encoding:
For patch at position (i, j):
PE_{ij} = LearnableEmbed2D[i, j] ∈ ℝ^{d_lang}
Add to patch features:
Z_{ij} = Z_{ij} + PE_{ij}
Final sequence:
Z_final = [Z_global, Z_{11}+PE_{11}, Z_{12}+PE_{12}, ..., Z_{nm}+PE_{nm}]
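A minimal sketch of such a learnable 2D position embedding (grid bounds and dimensions are illustrative, not the exact implementation):

import torch
import torch.nn as nn

class PatchPositionEmbedding2D(nn.Module):
    """Learnable embedding per (row, col) grid cell, added to each patch's tokens."""
    def __init__(self, max_rows: int = 6, max_cols: int = 6, dim: int = 4096):
        super().__init__()
        self.embed = nn.Parameter(torch.zeros(max_rows, max_cols, dim))
        nn.init.trunc_normal_(self.embed, std=0.02)

    def forward(self, patch_tokens, row: int, col: int):
        # patch_tokens: [N_p, dim] tokens for the patch at grid position (row, col)
        return patch_tokens + self.embed[row, col]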
Think of LLaVA-NeXT's any-resolution processing like a smart magnifying glass system:
Traditional Fixed-Resolution (LLaVA 1.0):
- Like viewing everything through the same lens
- Shrink large images to fit → lose details
- Same magnification for postage stamp and billboard
- Fast but loses fine details
LLaVA-NeXT Any-Resolution:
- Like having multiple magnifying glasses
- Global view: See the whole scene at once
- Detail patches: Examine important regions closely
- Adaptive: More patches for complex images, fewer for simple ones
- Balance: See both forest and trees
Example Scenario:
Reading a restaurant menu in an image:
Fixed Resolution:
[Shrink entire menu to 336×336]
→ Text becomes blurry and unreadable
→ Model can't read menu items
Any-Resolution:
[Keep 3×2 grid of patches at 336×336 each]
[Process global view for layout]
[Process each patch for text details]
→ Clear text in each region
→ Model can read and understand menu
The Teaching Process:
- Show: Present image with rich visual details
- Ask: Provide clear instruction ("Describe the main object")
- Guide: Use high-quality examples from GPT-4
- Reinforce: Train with diverse question-answer pairs
- Evaluate: Test on held-out tasks
Key Insight: The model learns not just to describe images, but to follow instructions about images, enabling:
- Answering specific questions
- Focusing on relevant details
- Reasoning about visual content
- Combining visual and textual information
Input Image (1920×1080)
↓
[Dynamic Resolution Analyzer]
↓
Grid: 3×2
↓
┌─────────────────────────────┐
│ Global View (336×336) │
│ Patch 1,1 Patch 1,2 │
│ Patch 2,1 Patch 2,2 │
│ Patch 3,1 Patch 3,2 │
└─────────────────────────────┘
↓
[Vision Encoder (ViT)]
↓
Visual Features (7 × 576 × 1024)
↓
[Any-Resolution Projector]
↓
Language Space (7 × 576 × 4096)
↓
[Position-Aware Concatenation]
↓
Visual Tokens (4032 × 4096)
↓
[Merge with Text Tokens]
↓
Combined Sequence
↓
[Language Model (Llama, Vicuna, etc.)]
↓
Generated Response
The implementation in nexus/models/multimodal/llava_next.py consists of three main components:
1. Dynamic Image Processor
class DynamicImageProcessor(NexusModule):
"""Handles variable-resolution image encoding."""
def __init__(
self,
base_size: int = 336, # Base resolution for patches
max_patches: int = 5, # Maximum number of patches
patch_size: int = 14, # ViT patch size
hidden_dim: int = 1024, # Vision encoder dimension
):
# Patch embedding layer (convolutional)
self.patch_embed = nn.Conv2d(
3, hidden_dim,
kernel_size=patch_size,
stride=patch_size
)
# Learnable positional embeddings
self.pos_embed = nn.Parameter(
torch.randn(1, (base_size // patch_size) ** 2, hidden_dim)
        )

Key Features:
- Convolutional patch embedding (more efficient than unfolding)
- Interpolated positional embeddings for variable sizes
- Handles non-square images naturally
2. Any-Resolution Projector
class AnyResolutionProjector(NexusModule):
"""Projects multi-resolution visual features to language space."""
def __init__(
self,
visual_dim: int = 1024,
text_dim: int = 4096,
num_layers: int = 2,
):
# Multi-layer MLP with GELU activations
layers = []
for i in range(num_layers):
in_dim = visual_dim if i == 0 else text_dim
layers.extend([
nn.Linear(in_dim, text_dim),
nn.GELU()
])
# Remove last activation
layers = layers[:-1]
        self.projector = nn.Sequential(*layers)

Design Choices:
- 2-layer MLP (balance between capacity and efficiency)
- GELU activation (better than ReLU for this task)
- No activation on final layer (preserve signal magnitude)
- Residual connections optional but not standard
3. Main LLaVA-NeXT Module
class LLaVANeXT(NexusModule):
"""Main multimodal model."""
def __init__(
self,
visual_encoder_dim: int = 1024,
language_model_dim: int = 4096,
num_visual_tokens: int = 576,
max_images: int = 8,
projector_layers: int = 2,
use_video: bool = False,
):
# Components
self.image_processor = DynamicImageProcessor(...)
self.projector = AnyResolutionProjector(...)
# Optional video encoder
if use_video:
self.temporal_encoder = nn.TransformerEncoder(...)
# Special tokens
self.image_start_token = nn.Parameter(...)
        self.image_end_token = nn.Parameter(...)

Basic Single-Image Inference:
from nexus.models.multimodal import LLaVANeXT
# Initialize model
model = LLaVANeXT(
visual_encoder_dim=1024,
language_model_dim=4096,
num_visual_tokens=576,
max_images=8
)
# Prepare inputs
image = load_image("photo.jpg") # [3, H, W]
image = image.unsqueeze(0) # [1, 3, H, W]
# Get visual features
output = model(images=image)
visual_features = output['visual_features']
# Shape: [1, num_patches * 576, 4096]
# Combine with text for generation
text = "Describe this image in detail."
text_embeds = tokenize_and_embed(text) # [1, seq_len, 4096]
output = model(images=image, text_embeds=text_embeds)
multimodal_embeds = output['multimodal_embeds']
# Feed to language model for generation

Multi-Image Processing:
# Multiple images in one sample
images = [
load_image("image1.jpg"),
load_image("image2.jpg"),
load_image("image3.jpg")
]
images = torch.stack(images) # [3, 3, H, W]
# Specify how many images per sample in batch
output = model(
images=images,
text_embeds=text_embeds,
num_images_per_sample=[3] # One sample with 3 images
)

Video Understanding (LLaVA-OneVision):
from nexus.models.multimodal import LLaVAOneVision
model = LLaVAOneVision(
visual_encoder_dim=1024,
language_model_dim=4096,
max_frames=32
)
# Load video frames
video = load_video("clip.mp4") # [1, num_frames, 3, H, W]
# Encode with temporal modeling
output = model(video_frames=video, text_embeds=text_embeds)
video_features = output['visual_features']
# Temporal encoder models relationships between frames

Recommended Hyperparameters:
# Stage 1: Visual-language alignment
config_stage1 = {
'learning_rate': 1e-3,
'batch_size': 256,
'epochs': 1,
'warmup_steps': 2000,
'weight_decay': 0.0,
'freeze_vision_encoder': True,
'freeze_language_model': True,
'trainable': ['projector']
}
# Stage 2: Instruction tuning
config_stage2 = {
'learning_rate': 2e-5,
'batch_size': 128,
'epochs': 1,
'warmup_ratio': 0.03,
'weight_decay': 0.0,
'freeze_vision_encoder': True,
'trainable': ['projector', 'language_model']
}
# Stage 3: High-resolution fine-tuning
config_stage3 = {
'learning_rate': 2e-5,
'batch_size': 64, # Smaller due to higher resolution
'epochs': 1,
'gradient_accumulation': 2,
'max_patches': 5, # Enable any-resolution
'trainable': ['all']
}Memory Optimization:
# Gradient checkpointing
model.language_model.gradient_checkpointing_enable()
# Mixed precision training
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
with autocast():
output = model(images, text_embeds)
loss = compute_loss(output, labels)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

Let's trace through how an image is processed:
def encode_images(
self,
images: torch.Tensor,
num_images_per_sample: Optional[List[int]] = None
) -> torch.Tensor:
"""
Args:
images: [total_images, 3, H, W]
num_images_per_sample: Number of images per batch sample
Returns:
Encoded features: [batch_size, num_visual_tokens, language_model_dim]
"""
# Step 1: Process images through vision encoder
# This applies patch embedding and vision transformer
visual_features = self.image_processor(images)
# Shape: [total_images, num_patches, visual_encoder_dim]
# Example: [5, 576, 1024] for 5 images
# Step 2: Project to language model space
# Multi-layer MLP with GELU
projected_features = self.projector(visual_features)
# Shape: [total_images, num_patches, language_model_dim]
# Example: [5, 576, 4096]
# Step 3: Reshape if multiple images per sample
if num_images_per_sample is not None:
batch_outputs = []
offset = 0
for num_imgs in num_images_per_sample:
# Get all images for this sample
sample_features = projected_features[offset:offset + num_imgs]
# Concatenate along sequence dimension
batch_outputs.append(sample_features.flatten(0, 1))
offset += num_imgs
projected_features = torch.stack(batch_outputs, dim=0)
# Shape: [batch_size, num_imgs * num_patches, language_model_dim]
        return projected_features

Key Operations:
- Patch Embedding: Convert image pixels to patch tokens
- Vision Transformer: Apply self-attention across patches
- Projection: Map visual features to language space
- Concatenation: Merge multiple images if needed
def forward(
self,
images: Optional[torch.Tensor] = None,
text_embeds: Optional[torch.Tensor] = None,
video_frames: Optional[torch.Tensor] = None,
num_images_per_sample: Optional[List[int]] = None,
attention_mask: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
"""Complete forward pass."""
outputs = {}
# Encode visual input
if video_frames is not None:
# Video path: encode frames + temporal modeling
visual_features = self.encode_video(video_frames)
elif images is not None:
# Image path: encode images
visual_features = self.encode_images(images, num_images_per_sample)
else:
raise ValueError("Either images or video_frames required")
outputs['visual_features'] = visual_features
# Combine with text if provided
if text_embeds is not None:
B = text_embeds.shape[0]
# Add special boundary tokens
image_start = self.image_start_token.expand(B, -1, -1)
image_end = self.image_end_token.expand(B, -1, -1)
# Concatenate: [START] [VISUAL] [END] [TEXT]
multimodal_embeds = torch.cat([
image_start,
visual_features,
image_end,
text_embeds
], dim=1)
outputs['multimodal_embeds'] = multimodal_embeds
# Extend attention mask to cover visual tokens
if attention_mask is not None:
visual_mask = torch.ones(
B, visual_features.shape[1] + 2,
device=attention_mask.device,
dtype=attention_mask.dtype
)
extended_mask = torch.cat([visual_mask, attention_mask], dim=1)
outputs['attention_mask'] = extended_mask
        return outputs

Flow:
- Encode visual input (images or video)
- Add special tokens to mark visual content boundaries
- Concatenate with text embeddings
- Extend attention mask to include visual tokens
- Return combined multimodal embedding
def encode_video(self, video_frames: torch.Tensor) -> torch.Tensor:
"""
Args:
video_frames: [batch_size, num_frames, 3, H, W]
Returns:
Video features: [batch_size, num_visual_tokens, language_model_dim]
"""
if not self.use_video:
raise ValueError("Video encoding not enabled")
B, T = video_frames.shape[:2]
# Step 1: Encode each frame independently
frames = video_frames.reshape(B * T, *video_frames.shape[2:])
frame_features = self.encode_images(frames)
# Shape: [B*T, num_patches, language_model_dim]
# Step 2: Reshape to separate batch and time
frame_features = frame_features.reshape(
B, T, -1, self.language_model_dim
)
# Step 3: Flatten spatial and temporal dimensions
frame_features = frame_features.flatten(1, 2)
# Shape: [B, T * num_patches, language_model_dim]
# Step 4: Apply temporal encoder (cross-frame attention)
video_features = self.temporal_encoder(frame_features)
# Models temporal relationships between frames
        return video_features

Video-Specific Processing:
- Treats video as sequence of images
- Encodes spatial features per frame
- Applies temporal transformer to model motion and change
- Produces unified video representation
Adaptive Resolution Selection:
def select_resolution(image, task_type):
"""Choose resolution based on task requirements."""
if task_type in ['ocr', 'text_vqa', 'document']:
# High resolution for text-heavy tasks
return 'high' # 672×672 or 4-5 patches
elif task_type in ['caption', 'vqa']:
# Medium resolution for general tasks
return 'medium' # 336×336 or 1-2 patches
elif task_type == 'classification':
# Low resolution sufficient for classification
return 'low' # 336×336, single patch
# Default: analyze image content
if has_small_text(image) or is_detailed_image(image):
return 'high'
else:
        return 'medium'

Trade-off Tuning:
# More patches = better quality but slower
config_high_quality = {'max_patches': 6, 'base_size': 336}
# Fewer patches = faster but may miss details
config_fast = {'max_patches': 2, 'base_size': 336}
# Balanced
config_balanced = {'max_patches': 4, 'base_size': 336}

Padding Strategy:
def collate_variable_resolution(batch):
"""Efficiently batch images of different resolutions."""
# Group by similar total token count
batch_sorted = sorted(batch, key=lambda x: x['num_patches'])
# Create mini-batches of similar sizes
mini_batches = []
current_batch = []
max_tokens = 0
for sample in batch_sorted:
tokens = sample['num_patches'] * 576 # patches × tokens_per_patch
if tokens > max_tokens:
if current_batch:
mini_batches.append(current_batch)
current_batch = [sample]
max_tokens = tokens
else:
current_batch.append(sample)
if current_batch:
mini_batches.append(current_batch)
    return mini_batches

Benefits:
- Reduces wasted computation on padding
- Better GPU utilization
- Faster training and inference
Gradient Checkpointing:
# Enable for vision encoder
model.image_processor.encoder.gradient_checkpointing = True
# Enable for language model
model.language_model.gradient_checkpointing_enable()
# Trade-off: 20-30% slower, but 40-50% less memory

Sequential Patch Processing (for very high resolution):
def process_patches_sequential(self, patches, chunk_size=4):
    """Process patches in small chunks to bound peak GPU memory usage."""
    features = []
    for patch_batch in torch.split(patches, chunk_size):
        with torch.no_grad():
            patch_features = self.image_processor(patch_batch)
        features.append(patch_features.cpu())
        torch.cuda.empty_cache()
    # Concatenate on CPU, move back to GPU for projection
    all_features = torch.cat(features, dim=0).cuda()
    projected = self.projector(all_features)
    return projected

Smart Padding:
def pad_to_aspect_ratio(image, target_ratio=1.0):
"""Pad image to target aspect ratio with minimal distortion."""
H, W = image.shape[-2:]
current_ratio = W / H
if abs(current_ratio - target_ratio) < 0.1:
return image # Already close enough
if current_ratio > target_ratio:
# Too wide, pad height
new_H = int(W / target_ratio)
pad_top = (new_H - H) // 2
pad_bottom = new_H - H - pad_top
return F.pad(image, (0, 0, pad_top, pad_bottom))
else:
# Too tall, pad width
new_W = int(H * target_ratio)
pad_left = (new_W - W) // 2
pad_right = new_W - W - pad_left
        return F.pad(image, (pad_left, pad_right, 0, 0))

Learning Rate Warmup:
def get_lr_schedule(optimizer, num_training_steps, warmup_ratio=0.03):
"""Linear warmup followed by cosine decay."""
num_warmup_steps = int(num_training_steps * warmup_ratio)
def lr_lambda(current_step):
if current_step < num_warmup_steps:
# Linear warmup
return current_step / max(1, num_warmup_steps)
# Cosine decay
progress = (current_step - num_warmup_steps) / \
max(1, num_training_steps - num_warmup_steps)
return 0.5 * (1 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

Gradient Clipping:
# Prevent exploding gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

KV Cache Reuse for visual tokens:
def generate_with_cache(model, visual_features, prompt, max_length=512):
"""Reuse KV cache for visual tokens during generation."""
# Encode visual features once
with torch.no_grad():
visual_embeds = model.encode_images(visual_features)
    # Initialize the KV cache with the visual tokens
    # (init_cache is a schematic placeholder; with a Hugging Face-style LLM this step
    #  would be a prefill forward pass over the visual embeddings with use_cache=True)
    past_key_values = model.language_model.init_cache(visual_embeds)
# Generate tokens autoregressively
text_ids = tokenizer(prompt)
for _ in range(max_length):
outputs = model.language_model(
text_ids[:, -1:], # Only new token
past_key_values=past_key_values,
use_cache=True
)
next_token = outputs.logits.argmax(-1)
text_ids = torch.cat([text_ids, next_token], dim=-1)
past_key_values = outputs.past_key_values
if next_token == tokenizer.eos_token_id:
break
    return text_ids

Benefits:
- Visual features computed only once
- Faster generation (no recomputation)
- Lower memory usage during generation
Visual Question Answering:
| Benchmark | LLaVA-1.5 | LLaVA-NeXT | Improvement |
|---|---|---|---|
| VQAv2 | 78.5% | 82.0% | +3.5% |
| GQA | 62.0% | 64.2% | +2.2% |
| TextVQA | 58.2% | 67.1% | +8.9% |
| VizWiz | 53.6% | 58.7% | +5.1% |
Visual Reasoning:
| Benchmark | LLaVA-1.5 | LLaVA-NeXT | Improvement |
|---|---|---|---|
| POPE (Accuracy) | 85.9% | 86.5% | +0.6% |
| MMBench | 65.6% | 70.5% | +4.9% |
| MM-Vet | 30.5% | 36.3% | +5.8% |
| LLaVA-Bench | 64.0% | 70.1% | +6.1% |
Text-Rich Images:
| Benchmark | LLaVA-1.5 | LLaVA-NeXT | Improvement |
|---|---|---|---|
| DocVQA | 43.9% | 55.8% | +11.9% |
| ChartQA | 45.2% | 62.3% | +17.1% |
| InfoVQA | 32.1% | 41.5% | +9.4% |
| OCR-VQA | 51.6% | 63.2% | +11.6% |
Key Observations:
- Massive improvements on text-rich tasks (10-17% gains)
- Significant gains on complex reasoning (5-6% on MM-Vet, LLaVA-Bench)
- Modest improvements on standard VQA (already near ceiling)
- Competitive with much larger proprietary models
| Model | Size | MMBench | TextVQA | ChartQA | Open Source |
|---|---|---|---|---|---|
| GPT-4V | Unknown | 75.1% | 78.0% | 78.5% | ✗ |
| Gemini Pro | Unknown | 73.6% | 74.6% | 74.1% | ✗ |
| LLaVA-NeXT-34B | 34B | 70.5% | 67.1% | 62.3% | ✓ |
| Qwen-VL-Chat | 9.6B | 61.8% | 63.8% | 66.3% | ✓ |
| CogVLM | 17B | 65.8% | 70.4% | 68.3% | ✓ |
Key Takeaways:
- LLaVA-NeXT approaches GPT-4V with much smaller size
- Best open-source model overall
- Transparent, reproducible, and accessible
Impact of Any-Resolution Processing:
| Configuration | TextVQA | DocVQA | ChartQA |
|---|---|---|---|
| Single 336×336 | 58.2% | 43.9% | 45.2% |
| 2×2 patches | 63.4% | 50.1% | 56.8% |
| Any-res (max 5) | 67.1% | 55.8% | 62.3% |
Improvement: +9-17% on text-heavy tasks
Number of Projection Layers:
| Layers | VQAv2 | MMBench | Speed (it/s) |
|---|---|---|---|
| 1 | 80.8% | 68.3% | 2.4 |
| 2 | 82.0% | 70.5% | 2.3 |
| 3 | 82.1% | 70.6% | 2.1 |
Conclusion: 2 layers optimal (good performance, minimal overhead)
Training Data Quality:
| Data Source | Amount | VQAv2 | MMBench |
|---|---|---|---|
| Filtered | 665K | 82.0% | 70.5% |
| + Low-quality | 1.2M | 80.2% | 67.8% |
| + Unfiltered | 2.5M | 78.5% | 64.2% |
Conclusion: Quality over quantity (careful curation crucial)
Example 1: Fine-grained OCR
Image: Restaurant menu with small text
Question: "What is the price of the Caesar Salad?"
LLaVA-1.5: "I cannot read the specific prices in this menu."
LLaVA-NeXT: "The Caesar Salad is priced at $12.95 according to the menu."
Example 2: Chart Understanding
Image: Line graph showing stock prices over time
Question: "What was the approximate stock price in March 2023?"
LLaVA-1.5: "The stock price appears to be around 150."
LLaVA-NeXT: "In March 2023, the stock price was approximately $147, showing a decline from the February peak of $165."
Example 3: Multi-Image Reasoning
Images: [Beach scene, Mountain scene, City scene]
Question: "Which location would be best for a winter vacation?"
LLaVA-NeXT: "The mountain scene would be best for a winter vacation, as it shows snow-covered peaks ideal for skiing and winter sports, unlike the beach which is more suitable for summer, or the city which is year-round."
Problem: Model sometimes generates plausible but incorrect details.
Example:
Image: Person wearing red shirt
Question: "What color is the person's hat?"
Wrong: "The person is wearing a blue hat."
Correct: "I don't see a hat in the image."
Mitigation:
- Train with negative examples (absence of objects)
- Use POPE (Polling-based Object Probing Evaluation) training data
- Add instruction: "If you're unsure, say you don't know"
- Lower temperature during generation (e.g., 0.2 instead of 0.7)
# Reduce hallucination with lower temperature
output = generate(
prompt,
temperature=0.2, # More conservative
top_p=0.9,
repetition_penalty=1.1
)

Problem: High resolution causes OOM (Out of Memory) errors.
Symptoms:
RuntimeError: CUDA out of memory. Tried to allocate 2.5 GB
Solutions:
# Solution 1: Limit the number of images (and thus visual tokens) per sample
model = LLaVANeXT(max_images=3)  # Reduce from the default of 8
# Solution 2: Gradient checkpointing
model.gradient_checkpointing_enable()
# Solution 3: Sequential processing
images_batched = split_into_smaller_batches(images, batch_size=2)
# Solution 4: CPU offloading for inference
model = LLaVANeXT(...).cuda()
model.image_processor = model.image_processor.cpu()
# Process images on CPU, move features to GPU

Problem: Loss spikes or diverges during training.
Symptoms:
Epoch 1: loss = 2.5
Epoch 2: loss = 2.3
Epoch 3: loss = 8.7 ← SPIKE
Epoch 4: loss = NaN
Causes and Fixes:
# Cause 1: Learning rate too high
# Fix: Lower LR and add warmup
optimizer = AdamW(model.parameters(), lr=2e-5) # Not 1e-4
scheduler = get_warmup_scheduler(optimizer, warmup_ratio=0.03)
# Cause 2: No gradient clipping
# Fix: Clip gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# Cause 3: Batch size too small (unstable gradients)
# Fix: Increase effective batch size via accumulation
accumulation_steps = 4
effective_batch_size = batch_size * accumulation_steps
# Cause 4: Bad data samples
# Fix: Add data filtering and validation
def validate_sample(sample):
if sample['image'] is None or sample['text'] is None:
return False
if len(sample['text']) < 5: # Too short
return False
if torch.isnan(sample['image']).any():
return False
    return True

Problem: Model doesn't effectively use multiple images.
Example:
Images: [Dog photo, Cat photo]
Question: "What animals are shown?"
Wrong: "There is a dog in the image." (only mentions first image)
Correct: "There is a dog and a cat in the images."
Fixes:
# Fix 1: Add image separator tokens
for i, img_feat in enumerate(image_features):
if i > 0:
combined.append(separator_token)
combined.append(img_feat)
# Fix 2: Train with multi-image data
# Include examples explicitly referencing multiple images
# Fix 3: Add cross-image attention
class CrossImageAttention(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, image_features_list):
        # Allow attention between tokens from different images
        all_features = torch.cat(image_features_list, dim=1)
        attended, _ = self.attention(all_features, all_features, all_features)
        return attended

Problem: Generation is too slow for production use.
Typical Speed: 1-2 tokens/second (too slow)
Optimizations:
# Optimization 1: Reduce visual token count
# (assumes the implementation exposes a visual-token compression option)
model = LLaVANeXT(compression_ratio=2)  # Compress visual features
# Optimization 2: Use smaller backbone
# Use Vicuna-7B instead of Vicuna-13B
# Optimization 3: Quantization
from transformers import BitsAndBytesConfig
model = LLaVANeXT.from_pretrained(
"llava-next-7b",
quantization_config=BitsAndBytesConfig(load_in_4bit=True)
)
# Optimization 4: Compile with torch.compile (PyTorch 2.0+)
model = torch.compile(model, mode="reduce-overhead")
# Optimization 5: Batch processing
# Process multiple requests in parallel

Expected Improvement: 3-5x faster
Problem: Model performs poorly on domain-specific images (medical, satellite, etc.)
Example:
Image: X-ray showing fracture
Question: "Describe the findings."
Wrong: "This appears to be a black and white photograph of bones."
(Misses medical context)
Solutions:
# Solution 1: Domain-specific fine-tuning
# Fine-tune on medical/scientific images with expert annotations
# Solution 2: Specialized prompting
prompt = """You are an expert radiologist. Analyze this X-ray image
and describe any abnormalities using medical terminology."""
# Solution 3: Ensemble with domain-specific models
general_output = llava_next(image, question)
specialized_output = biomedclip(image, question)
final_output = combine(general_output, specialized_output)
# Solution 4: Few-shot prompting with domain examples
prompt_with_examples = f"""
Example 1: [Medical image] → [Expert description]
Example 2: [Medical image] → [Expert description]
Now analyze: {question}
"""-
LLaVA: Visual Instruction Tuning (April 2023)
- Authors: Haotian Liu, Chunyuan Li, Qingyang Wu, Yong Jae Lee
- Link: https://arxiv.org/abs/2304.08485
- Key Contribution: First practical visual instruction tuning approach
- Improved Baselines with Visual Instruction Tuning (October 2023)
  - Authors: Haotian Liu, Chunyuan Li, Yuheng Li, Yong Jae Lee
  - Link: https://arxiv.org/abs/2310.03744
  - Key Contribution: LLaVA-1.5 with improved projector and data
- LLaVA-NeXT: Improved Reasoning, OCR, and World Knowledge (January 2024)
  - Blog: https://llava-vl.github.io/blog/2024-01-30-llava-next/
  - Key Contribution: Any-resolution processing, enhanced spatial reasoning
- LLaVA-OneVision: Easy Visual Task Transfer (August 2024)
  - Blog: https://llava-vl.github.io/blog/2024-08-05-llava-onevision/
  - Link: https://arxiv.org/abs/2408.03326
  - Key Contribution: Unified image and video understanding
- Visual Instruction Tuning with Multimodal LLMs
  - Flamingo: https://arxiv.org/abs/2204.14198
  - BLIP-2: https://arxiv.org/abs/2301.12597
  - InstructBLIP: https://arxiv.org/abs/2305.06500
- High-Resolution Visual Understanding
  - Pix2Struct: https://arxiv.org/abs/2210.03347
  - Qwen-VL: https://arxiv.org/abs/2308.12966
- Official Repository
  - GitHub: https://github.com/haotian-liu/LLaVA
  - Includes training code, model weights, evaluation scripts
- Model Weights
  - Hugging Face: https://huggingface.co/liuhaotian
  - Available sizes: 7B, 13B, 34B variants
- Datasets
  - LLaVA-Instruct-158K: https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K
  - LLaVA-Bench: https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild
  - Comprehensive Multimodal Benchmarks
- Tutorials and Guides
  - Hugging Face Integration: https://huggingface.co/docs/transformers/model_doc/llava_next
  - Fine-tuning Guide: https://github.com/haotian-liu/LLaVA/blob/main/docs/Finetune_Custom_Data.md
- Discussion and Support
  - Discord: LLaVA Community
  - GitHub Issues: https://github.com/haotian-liu/LLaVA/issues