# video_vae_modular_final.py

# ==============================================================================
# 1. IMPORTS & CONFIGURATION
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoImageProcessor, AutoProcessor, Qwen2_5_VLForConditionalGeneration
from typing import List, Dict
from dataclasses import dataclass

@dataclass
class VideoVAEConfig:
    in_channels: int = 3
    base_ch: int = 64
    num_blocks: int = 4
    quant_emb_dim: int = 16
    alignment_dim: int = 256
    quant_align_loss_weight: float = 0.1
    likelihood_loss_weight: float = 0.2
    dino_loss_weight: float = 0.25
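
# Shape note (an assumption, not enforced anywhere below): every encoder stage halves
# T, H and W, so inputs of shape (B, in_channels, T, H, W) should have T, H and W
# divisible by 2**num_blocks. The demo at the bottom of this file uses 16x64x64 clips.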

# ==============================================================================
# 2. PERCEPTUAL & TEXT MODULES
# ==============================================================================

class DINOv2Extractor(nn.Module):
    """
    A frozen DINOv2 model to extract perceptual features from video frames.
    """
    def __init__(self, device="cuda"):
        super().__init__()
        self.device = device
        model_name = "facebook/dinov2-base"
        print("Loading DINOv2 model and processor...")
        # The image processor is only used for its normalization constants; preprocessing
        # is done in torch below so the perceptual loss stays differentiable end to end.
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device).eval()
        for param in self.model.parameters():
            param.requires_grad = False
        print("DINOv2 loaded and frozen successfully. 🦖")

    def forward(self, video_tensor: torch.Tensor) -> torch.Tensor:
        b, c, t, h, w = video_tensor.shape
        frames = video_tensor.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        # Assumes frames are in [-1, 1] (the VAE decodes through tanh): map to [0, 1],
        # resize to DINOv2's 224x224 input and normalize in torch, keeping the graph
        # intact so gradients can reach the reconstruction (the backbone stays frozen).
        frames = (frames.clamp(-1.0, 1.0) + 1.0) / 2.0
        frames = F.interpolate(frames, size=(224, 224), mode="bilinear", align_corners=False)
        mean = frames.new_tensor(self.processor.image_mean).view(1, -1, 1, 1)
        std = frames.new_tensor(self.processor.image_std).view(1, -1, 1, 1)
        frames = (frames - mean) / std
        outputs = self.model(pixel_values=frames.to(self.device))
        # Return the features of the [CLS] token, one vector per frame: (B, T, D)
        return outputs.last_hidden_state[:, 0].view(b, t, -1)
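
# Minimal usage sketch (hedged): for a clip in [-1, 1], dinov2-base yields one 768-dim
# [CLS] vector per frame, e.g.
#   feats = DINOv2Extractor(device="cpu")(torch.rand(2, 3, 4, 64, 64) * 2 - 1)  # -> (2, 4, 768)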

class QwenVLTextEncoder(nn.Module):
    """A frozen Qwen-VL model to extract text embeddings."""
    def __init__(self, device="cuda"):
        super().__init__()
        # Assumes a transformers release with native Qwen2.5-VL support (>= 4.49) and the
        # 7B instruct checkpoint; swap in a smaller checkpoint if memory is tight.
        model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
        self.device = device
        print("Loading Qwen-VL model and processor...")
        self.processor = AutoProcessor.from_pretrained(model_id)
        # Loaded onto a single device so the outer `model.to(device)` call works cleanly.
        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_id, torch_dtype="auto"
        ).to(self.device).eval()
        for param in self.model.parameters():
            param.requires_grad = False
        print("Qwen-VL loaded and frozen successfully. 🥶")

    def forward(self, text_prompts: List[str]) -> torch.Tensor:
        messages = [[{"role": "user", "content": [{"type": "text", "text": prompt}]}] for prompt in text_prompts]
        # Render the chat template to plain strings, then tokenize as a padded batch.
        texts = [self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=False) for m in messages]
        text_inputs = self.processor(text=texts, return_tensors="pt", padding=True).to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**text_inputs, output_hidden_states=True)
        # Last-layer hidden states serve as the text embedding sequence: (B, L, hidden_size).
        # Cast to float32 so the (typically bf16) states mix cleanly with the fp32 VAE modules.
        return outputs.hidden_states[-1].to(self.device, dtype=torch.float32)
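
# The returned embedding has shape (B, L_text, hidden_size); for the assumed 7B checkpoint
# hidden_size is 3584. Padding tokens remain in the sequence and are simply attended over
# downstream, since the text is used only as a conditioning memory.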

class TextVideoCrossAttention(nn.Module):
    """Performs cross-attention between video features (Q) and text features (K,V)."""
    def __init__(self, video_channels, text_embed_dim):
        super().__init__()
        self.q_proj = nn.Linear(video_channels, video_channels)
        self.k_proj = nn.Linear(text_embed_dim, video_channels)
        self.v_proj = nn.Linear(text_embed_dim, video_channels)
        self.out_proj = nn.Linear(video_channels, video_channels)

    def forward(self, video_feat, text_embedding):
        B, C, T, H, W = video_feat.shape
        # Flatten the video volume into a sequence of T*H*W query tokens.
        video_seq = video_feat.permute(0, 2, 3, 4, 1).reshape(B, T * H * W, C)
        q, k, v = self.q_proj(video_seq), self.k_proj(text_embedding), self.v_proj(text_embedding)
        # Single-head attention: q, k and v all receive an explicit head dimension so the
        # batched shapes broadcast correctly for any batch size.
        attn_output = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)).squeeze(1)
        return self.out_proj(attn_output).reshape(B, T, H, W, C).permute(0, 4, 1, 2, 3)
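
# Example shapes (hedged, assuming B=2, C=128, T=4, H=W=16 and 20 text tokens of width 3584):
#   video_seq (2, 1024, 128) attends over k/v (2, 20, 128) and the module returns
#   (2, 128, 4, 16, 16), the same shape as video_feat, so callers can add it back as a residual.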

# ==============================================================================
# 3. CORE ARCHITECTURAL BLOCKS
# ==============================================================================

class ProjectedLFQ(nn.Module):
    """Projects features and quantizes them, returning an entropy loss."""
    def __init__(self, in_channels, quant_channels, entropy_loss_weight=0.1):
        super().__init__()
        self.project = nn.Conv3d(in_channels, quant_channels, 1)
        self.entropy_loss_weight = entropy_loss_weight

    def forward(self, x):
        x_proj = self.project(x)
        # Lookup-free quantization: take the sign of each projected channel, with a
        # straight-through estimator so gradients still reach the projection.
        quantized_x_hard = torch.where(x_proj > 0, 1.0, -1.0)
        quantized_x = x_proj + (quantized_x_hard - x_proj).detach()
        indices = (quantized_x > 0).long()
        # Per-bit usage statistics: maximizing each bit's binary entropy (by minimizing
        # its negative) pushes code usage toward 50/50 and discourages dead bits.
        probs = indices.float().mean(dim=(0, 2, 3, 4))
        entropy = - (probs * torch.log(probs.clamp(min=1e-8)) + (1 - probs) * torch.log((1 - probs).clamp(min=1e-8)))
        entropy_loss = -entropy.mean() * self.entropy_loss_weight
        return quantized_x, indices, entropy_loss
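
# A small, hedged illustration of how the per-channel LFQ bits pack into integer codebook
# ids (the same packing appears later in VideoVAE.calculate_losses). `bits_to_ids` is a
# helper added for this sketch only; it is not used by the training path.
def bits_to_ids(indices: torch.Tensor) -> torch.Tensor:
    """Packs {0,1} bits of shape (B, K, T, H, W) into ids in [0, 2**K), shape (B, T*H*W)."""
    b, k = indices.shape[0], indices.shape[1]
    bits = indices.view(b, k, -1).permute(0, 2, 1)         # (B, N, K)
    powers = 2 ** torch.arange(k, device=indices.device)   # (K,)
    return (bits * powers).sum(dim=-1)                      # (B, N)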

class VideoVAEEncoderBlock(nn.Module):
    """Standard VAE encoder block for downsampling."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        # Separate norm layers so the two convolutions do not share BatchNorm statistics.
        self.norm1 = nn.BatchNorm3d(out_ch)
        self.norm2 = nn.BatchNorm3d(out_ch)
        self.act = nn.GELU()

    def forward(self, x):
        h = self.act(self.norm1(self.conv1(x)))
        h = self.act(self.norm2(self.conv2(h)))
        return self.pool(h)

class PyramidalLFQBlock(nn.Module):
    """A block in the pyramidal upsampler: upsample -> fuse -> text-attend -> quantize."""
    def __init__(self, in_ch, skip_ch, out_ch, text_embed_dim, quant_emb_dim):
        super().__init__()
        self.upsample = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.conv = nn.Conv3d(out_ch + skip_ch, out_ch, kernel_size=3, padding=1)
        self.text_cross_attn = TextVideoCrossAttention(out_ch, text_embed_dim)
        self.lfq = ProjectedLFQ(out_ch, quant_channels=quant_emb_dim)
        self.norm = nn.BatchNorm3d(out_ch)
        self.act = nn.GELU()

    def forward(self, x, skip, text_embedding):
        x_up = self.upsample(x)
        # Fuse with the encoder skip connection, then inject text conditioning as a residual.
        x_fused = self.act(self.norm(self.conv(torch.cat([x_up, skip], dim=1))))
        h_attn = x_fused + self.text_cross_attn(x_fused, text_embedding)
        q, indices, entropy_loss = self.lfq(h_attn)
        return h_attn, q, indices, entropy_loss
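
# Example (hedged): with in_ch=512, skip_ch=out_ch=256 and a (B, 512, 1, 4, 4) input,
# x_up is (B, 256, 2, 8, 8); after skip fusion and text attention the block returns
# h_attn (B, 256, 2, 8, 8), q and indices of shape (B, 16, 2, 8, 8), and a scalar entropy term.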

class VideoVAEDecoderBlock(nn.Module):
    """Standard VAE decoder block for upsampling."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.conv = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1)
        # Separate norm layers for the transposed convolution and the refinement convolution.
        self.norm1 = nn.BatchNorm3d(out_ch)
        self.norm2 = nn.BatchNorm3d(out_ch)
        self.act = nn.GELU()

    def forward(self, x):
        h = self.act(self.norm1(self.upsample(x)))
        return self.act(self.norm2(self.conv(h)))

# ==============================================================================
# 4. PRIMARY VideoVAE MODEL
# ==============================================================================

class VideoVAE(nn.Module):
    """
    A modular, text-conditioned Video VAE with a Pyramidal LFQ structure
    and multiple perception-based losses for high-quality synthesis.
    """
    def __init__(self, cfg: VideoVAEConfig, device="cuda"):
        super().__init__()
        self.cfg = cfg
        self.device = device

        # --- Sub-models (Text, Perception) ---
        self.text_encoder = QwenVLTextEncoder(device=device)
        # The hidden size lives on `text_config` in newer transformers releases and on the
        # top-level config in older ones.
        model_config = self.text_encoder.model.config
        text_cfg = getattr(model_config, "text_config", None) or model_config
        text_embed_dim = text_cfg.hidden_size
        # DINOv2 is only needed for the perceptual loss. Note that nn.Module.training is
        # always True at construction time, so the extractor is loaded unconditionally here;
        # drop this line manually for an inference-only deployment.
        self.dino_extractor = DINOv2Extractor(device=device)

        # --- VAE Encoder ---
        self.enc_blocks = nn.ModuleList()
        chs = [cfg.base_ch * (2**i) for i in range(cfg.num_blocks)]
        current_ch = cfg.in_channels
        for ch in chs:
            self.enc_blocks.append(VideoVAEEncoderBlock(current_ch, ch))
            current_ch = ch

        # --- Pyramidal LFQ Upsampler ---
        rev_channels = list(reversed(chs))
        self.pyramid_blocks = nn.ModuleList()
        for i in range(2):  # 2 stages for 4x total upscaling
            self.pyramid_blocks.append(
                PyramidalLFQBlock(rev_channels[i], rev_channels[i+1], rev_channels[i+1], text_embed_dim, cfg.quant_emb_dim)
            )

        # --- VAE Decoder ---
        self.dec_blocks = nn.ModuleList()
        decoder_channels = [chs[1], chs[0]]
        for i in range(len(decoder_channels)):
            in_ch = decoder_channels[i]
            out_ch = decoder_channels[i+1] if i + 1 < len(decoder_channels) else cfg.base_ch
            self.dec_blocks.append(VideoVAEDecoderBlock(in_ch, out_ch))
        self.out_conv = nn.Conv3d(cfg.base_ch, cfg.in_channels, 1)

        # --- Loss-specific Modules ---
        # With quant_emb_dim bits per code the implicit codebook has 2**quant_emb_dim entries
        # (65,536 for the default 16), so these two layers dominate the trainable parameters.
        codebook_size = 2**cfg.quant_emb_dim
        self.quant_embedding = nn.Embedding(codebook_size, text_embed_dim)
        self.to_quant_logits = nn.Linear(text_embed_dim, codebook_size)
        # Each pyramid stage contributes one pooled LFQ vector with quant_emb_dim channels.
        quant_pooled_dim = cfg.quant_emb_dim * len(self.pyramid_blocks)
        self.quant_proj = nn.Linear(quant_pooled_dim, cfg.alignment_dim)
        self.text_proj_for_quant = nn.Linear(text_embed_dim, cfg.alignment_dim)

    def forward(self, x: torch.Tensor, text_prompts: List[str]) -> Dict[str, torch.Tensor]:
        """
        Core inference path. Encodes, quantizes via pyramid, and decodes.
        Returns all intermediate products needed for loss calculation.
        """
        text_embedding = self.text_encoder(text_prompts)

        encoder_features = []
        h = x
        for block in self.enc_blocks:
            h = block(h)
            encoder_features.append(h)

        rev_features = list(reversed(encoder_features))
        h = rev_features[0]
        pyramid_outputs = {'q': [], 'indices': [], 'entropies': []}
        for i, block in enumerate(self.pyramid_blocks):
            h, q, indices, entropy = block(h, rev_features[i + 1], text_embedding)
            pyramid_outputs['q'].append(q)
            pyramid_outputs['indices'].append(indices)
            pyramid_outputs['entropies'].append(entropy)

        dec_in = h
        for block in self.dec_blocks:
            dec_in = block(dec_in)
        reconstruction = torch.tanh(self.out_conv(dec_in))

        return {
            "reconstruction": reconstruction,
            "text_embedding": text_embedding,
            "pyramid_outputs": pyramid_outputs
        }
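
    # For a 16x64x64 clip with the default config, pyramid_outputs["indices"] holds two
    # bit maps of shapes (B, 16, 2, 8, 8) and (B, 16, 4, 16, 16); "q" holds the matching
    # straight-through quantized features and "entropies" two scalar loss terms.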

    def calculate_losses(self, original_video: torch.Tensor, forward_outputs: Dict) -> Dict:
        """
        Calculates all training-specific losses. This method should only be
        called during the training loop.
        """
        if not self.training:
            raise RuntimeError("calculate_losses() should only be called in training mode.")

        # Unpack forward pass results
        recon = forward_outputs["reconstruction"]
        text_emb = forward_outputs["text_embedding"]
        pyramid_out = forward_outputs["pyramid_outputs"]
        all_q, all_indices, all_entropies = pyramid_out['q'], pyramid_out['indices'], pyramid_out['entropies']

        # 1. Reconstruction Loss
        recon_loss = F.mse_loss(recon, original_video)

        # 2. Entropy Loss
        entropy_loss = sum(all_entropies)

        # 3. P(Q|text) Likelihood Loss
        # Pack the per-level bit maps into integer codebook ids, embed them, and ask the
        # frozen Qwen model to predict each code from the text plus the preceding codes.
        B = text_emb.size(0)
        seqs = [idx.view(B, self.cfg.quant_emb_dim, -1) for idx in all_indices]
        full_seq_bits = torch.cat(seqs, dim=2).permute(0, 2, 1)                        # (B, N, K)
        powers_of_2 = (2**torch.arange(self.cfg.quant_emb_dim, device=self.device)).float()
        quant_token_ids = (full_seq_bits * powers_of_2).sum(dim=2).long()              # (B, N)
        quant_embeds = self.quant_embedding(quant_token_ids)
        # Match the frozen LM's dtype (it is typically loaded in bf16) before the forward pass.
        combined_embeds = torch.cat([text_emb, quant_embeds], dim=1).to(self.text_encoder.model.dtype)
        # Running the LM under no_grad keeps memory low; as a consequence only
        # `to_quant_logits` receives gradient from this loss term.
        with torch.no_grad():
            qwen_outputs = self.text_encoder.model(inputs_embeds=combined_embeds, output_hidden_states=True)
        # The hidden state at position i-1 predicts code i, so slice from the last text
        # position up to (but excluding) the final code position.
        last_hidden = qwen_outputs.hidden_states[-1][:, text_emb.shape[1] - 1:-1, :].float()
        pred_logits = self.to_quant_logits(last_hidden)
        likelihood_loss = F.cross_entropy(pred_logits.reshape(-1, pred_logits.size(-1)), quant_token_ids.reshape(-1))

        # 4. Quantized Vector-Text Alignment Loss
        q_pooled = [F.adaptive_avg_pool3d(q, 1).view(B, -1) for q in all_q]
        q_pooled_cat = torch.cat(q_pooled, dim=1)
        text_pooled = text_emb.mean(dim=1)
        q_aligned = self.quant_proj(q_pooled_cat)
        text_aligned = self.text_proj_for_quant(text_pooled)
        quant_align_loss = F.cosine_embedding_loss(q_aligned, text_aligned, torch.ones(B, device=self.device))

        # 5. DINOv2 Perceptual Loss (KL divergence between per-frame feature distributions)
        orig_dino_feats = self.dino_extractor(original_video)
        recon_dino_feats = self.dino_extractor(recon)
        p = F.softmax(orig_dino_feats, dim=-1)
        q = F.log_softmax(recon_dino_feats, dim=-1)
        dino_loss = F.kl_div(q, p, reduction='batchmean')

        # --- Final Weighted Sum ---
        total_loss = (recon_loss + entropy_loss +
                      self.cfg.likelihood_loss_weight * likelihood_loss +
                      self.cfg.quant_align_loss_weight * quant_align_loss +
                      self.cfg.dino_loss_weight * dino_loss)

        return {
            "total_loss": total_loss, "reconstruction_loss": recon_loss,
            "entropy_loss": entropy_loss, "likelihood_loss": likelihood_loss,
            "quant_alignment_loss": quant_align_loss, "dino_perceptual_loss": dino_loss
        }

# ==============================================================================
# 5. EXAMPLE USAGE
# ==============================================================================
if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu": print("WARNING: Running on CPU. This will be extremely slow.")

    try:
        config = VideoVAEConfig(quant_emb_dim=16)  # Set LFQ size to 16 bits per code
        model = VideoVAE(config, device=device).to(device)

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("-" * 40)
        print(f"Trainable model parameters: {trainable_params:,}")
        print("(This should NOT include frozen DINOv2 or Qwen-VL models)")
        print("-" * 40)

        # --- SIMULATED TRAINING STEP ---
        print("\n--- 1. Simulating Training Step ---")
        model.train()  # Set model to training mode
        batch_size = 2
        # Fake video in [-1, 1], matching the range the tanh decoder and the DINOv2 loss assume.
        video_input = (torch.rand(batch_size, 3, 16, 64, 64) * 2 - 1).to(device)
        prompts = ["A stunning sunrise over a calm ocean.", "A busy city street at night with neon lights."]

        # In a real training loop, this would be inside the loop
        optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-4)
        optimizer.zero_grad()

        forward_outputs = model(video_input, text_prompts=prompts)
        losses = model.calculate_losses(video_input, forward_outputs)

        # Backpropagation
        losses["total_loss"].backward()
        optimizer.step()

        print("Training step successful. Losses calculated:")
        for name, value in losses.items(): print(f"  - {name:<25}: {value.item():.4f}")

        # --- SIMULATED INFERENCE STEP ---
        print("\n--- 2. Simulating Inference Step ---")
        model.eval()  # Set model to evaluation mode
        with torch.no_grad():
            # Notice we only call the forward pass and don't need the loss function
            inference_outputs = model(video_input, text_prompts=prompts)
            reconstructed_video = inference_outputs["reconstruction"]

        print("Inference step successful.")
        print("Input Video Shape:         ", video_input.shape)
        print("Reconstructed Video Shape: ", reconstructed_video.shape)

    except Exception as e:
        print("\n--- ❌ An Error Occurred ---")
        print(f"Error: {e}")
        if "out of memory" in str(e).lower():
            print("\n💡 Suggestion: CUDA Out-of-Memory. Try reducing `base_ch`, `num_blocks`, or input resolution.")