Merge commit 4b7224f (2 parents: 1f9ab80 + 30aef0a)

1 file changed: VAE_model.py (+353 lines, 0 deletions)

# video_vae_modular_final.py

# ==============================================================================
# 1. IMPORTS & CONFIGURATION
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoProcessor, AutoModelForCausalLM
from typing import List, Dict
from dataclasses import dataclass

@dataclass
class VideoVAEConfig:
    in_channels: int = 3
    base_ch: int = 64
    num_blocks: int = 4
    quant_emb_dim: int = 16
    alignment_dim: int = 256
    quant_align_loss_weight: float = 0.1
    likelihood_loss_weight: float = 0.2
    dino_loss_weight: float = 0.25

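# Illustrative note (a sketch of the default configuration above, not executed
# code): the encoder channel pyramid works out to
#
#   chs = [base_ch * 2**i for i in range(num_blocks)]   # -> [64, 128, 256, 512]
#
# and each VideoVAEEncoderBlock halves (T, H, W), so num_blocks = 4 gives a 16x
# reduction: a (3, 16, 64, 64) clip is encoded to a (512, 1, 4, 4) feature map
# before the pyramidal LFQ stages upsample it again.
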
# ==============================================================================
# 2. PERCEPTUAL & TEXT MODULES
# ==============================================================================

class DINOv2Extractor(nn.Module):
    """
    A frozen DINOv2 model to extract perceptual features from video frames.
    """
    def __init__(self, device="cuda"):
        super().__init__()
        self.device = device
        model_name = "facebook/dinov2-base"
        print("Loading DINOv2 model and processor...")
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device).eval()
        for param in self.model.parameters():
            param.requires_grad = False
        print("DINOv2 loaded and frozen successfully. 🦖")

    def forward(self, video_tensor: torch.Tensor) -> torch.Tensor:
        # Flatten the time axis into the batch so every frame is embedded as an image.
        b, c, t, h, w = video_tensor.shape
        frames = video_tensor.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
        # Preprocess in torch (resize + normalisation with the processor's statistics)
        # so the pass stays differentiable and the DINO perceptual loss can
        # backpropagate into the reconstruction. Inputs are expected roughly in [0, 1].
        frames = F.interpolate(frames, size=(224, 224), mode="bilinear", align_corners=False)
        mean = torch.tensor(self.processor.image_mean, device=frames.device).view(1, 3, 1, 1)
        std = torch.tensor(self.processor.image_std, device=frames.device).view(1, 3, 1, 1)
        frames = ((frames - mean) / std).to(self.device)
        # Weights are frozen above, so no parameter gradients accumulate here.
        outputs = self.model(pixel_values=frames)
        # Return the features of the [CLS] token, one embedding per frame.
        return outputs.last_hidden_state[:, 0].view(b, t, -1)

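# Shape sketch (assuming facebook/dinov2-base, whose [CLS] embedding is 768-dim;
# illustrative only):
#
#   clip  = torch.rand(2, 3, 16, 64, 64)      # (B, C, T, H, W), values ~[0, 1]
#   feats = DINOv2Extractor()(clip)           # -> (2, 16, 768), one vector per frame
#
# These per-frame descriptors are later compared between the original clip and
# the reconstruction by the KL term in VideoVAE.calculate_losses().
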
class QwenVLTextEncoder(nn.Module):
    """A frozen Qwen-VL model to extract text embeddings."""
    def __init__(self, device="cuda"):
        super().__init__()
        model_id = "Qwen/Qwen2.5-VL-7B-Instruct"  # any Qwen2.5-VL instruct checkpoint size works here
        self.device = device
        print("Loading Qwen-VL model and processor...")
        self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto", trust_remote_code=True).eval()
        for param in self.model.parameters():
            param.requires_grad = False
        print("Qwen-VL loaded and frozen successfully. 🥶")

    def forward(self, text_prompts: list[str]):
        messages = [[{"role": "user", "content": [{"type": "text", "text": prompt}]}] for prompt in text_prompts]
        # Render each conversation with the chat template, then tokenize as a padded batch.
        texts = [self.processor.apply_chat_template(m, tokenize=False, add_generation_prompt=True) for m in messages]
        text_inputs = self.processor(text=texts, return_tensors="pt", padding=True).to(self.model.device)
        with torch.no_grad():
            outputs = self.model(**text_inputs, output_hidden_states=True)
        return outputs.hidden_states[-1].to(self.device)

class TextVideoCrossAttention(nn.Module):
    """Performs cross-attention between video features (Q) and text features (K, V)."""
    def __init__(self, video_channels, text_embed_dim):
        super().__init__()
        self.q_proj = nn.Linear(video_channels, video_channels)
        self.k_proj = nn.Linear(text_embed_dim, video_channels)
        self.v_proj = nn.Linear(text_embed_dim, video_channels)
        self.out_proj = nn.Linear(video_channels, video_channels)

    def forward(self, video_feat, text_embedding):
        B, C, T, H, W = video_feat.shape
        # Flatten the video feature map into a (B, T*H*W, C) token sequence.
        video_seq = video_feat.permute(0, 2, 3, 4, 1).reshape(B, T * H * W, C)
        q, k, v = self.q_proj(video_seq), self.k_proj(text_embedding), self.v_proj(text_embedding)
        # Single-head attention: give q, k and v a matching head dimension.
        attn_output = F.scaled_dot_product_attention(q.unsqueeze(1), k.unsqueeze(1), v.unsqueeze(1)).squeeze(1)
        return self.out_proj(attn_output).reshape(B, T, H, W, C).permute(0, 4, 1, 2, 3)

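# Usage sketch (hypothetical sizes, illustrative only): the video feature map is
# flattened into T*H*W query tokens that attend over the text tokens, so the
# block is shape-preserving and can be added residually (as PyramidalLFQBlock does):
#
#   video_feat = torch.randn(2, 128, 4, 16, 16)   # (B, C, T, H, W)
#   text_emb   = torch.randn(2, 12, 1024)         # (B, L_text, text_embed_dim)
#   out        = TextVideoCrossAttention(128, 1024)(video_feat, text_emb)
#   # out has the same shape as video_feat: (2, 128, 4, 16, 16)
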
# ==============================================================================
# 3. CORE ARCHITECTURAL BLOCKS
# ==============================================================================

class ProjectedLFQ(nn.Module):
    """Projects features and quantizes them (lookup-free quantization), returning an entropy loss."""
    def __init__(self, in_channels, quant_channels, entropy_loss_weight=0.1):
        super().__init__()
        self.project = nn.Conv3d(in_channels, quant_channels, 1)
        self.entropy_loss_weight = entropy_loss_weight

    def forward(self, x):
        x_proj = self.project(x)
        # Binarize to ±1 with a straight-through estimator so gradients reach x_proj.
        quantized_x_hard = torch.where(x_proj > 0, 1.0, -1.0)
        quantized_x = x_proj + (quantized_x_hard - x_proj).detach()
        indices = (quantized_x > 0).long()
        # Per-bit usage statistics; maximizing their entropy (i.e. minimizing its
        # negative) pushes every bit towards 50% occupancy, encouraging full codebook usage.
        probs = indices.float().mean(dim=(0, 2, 3, 4))
        entropy = -(probs * torch.log(probs.clamp(min=1e-8)) + (1 - probs) * torch.log((1 - probs).clamp(min=1e-8)))
        entropy_loss = -entropy.mean() * self.entropy_loss_weight
        return quantized_x, indices, entropy_loss

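# Toy walkthrough of the LFQ step above (illustrative numbers, not executed code):
#
#   x_proj = torch.randn(2, 16, 2, 4, 4)              # (B, quant_channels, T, H, W)
#   hard   = torch.where(x_proj > 0, 1.0, -1.0)       # hard ±1 codes
#   q      = x_proj + (hard - x_proj).detach()        # straight-through: grads reach x_proj
#   probs  = (q > 0).float().mean(dim=(0, 2, 3, 4))   # per-bit usage in [0, 1]
#
# Minimizing -entropy(probs) rewards probs ≈ 0.5 for every bit, so the whole
# implicit 2**quant_channels codebook gets used.
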
class VideoVAEEncoderBlock(nn.Module):
    """Standard VAE encoder block for downsampling."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        # Note: a single BatchNorm3d is shared by both convolutions.
        self.norm = nn.BatchNorm3d(out_ch)
        self.act = nn.GELU()

    def forward(self, x):
        h = self.act(self.norm(self.conv1(x)))
        h = self.act(self.norm(self.conv2(h)))
        return self.pool(h)

class PyramidalLFQBlock(nn.Module):
    """A block in the pyramidal upsampler: upsample -> fuse -> text-attend -> quantize."""
    def __init__(self, in_ch, skip_ch, out_ch, text_embed_dim, quant_emb_dim):
        super().__init__()
        self.upsample = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.conv = nn.Conv3d(out_ch + skip_ch, out_ch, kernel_size=3, padding=1)
        self.text_cross_attn = TextVideoCrossAttention(out_ch, text_embed_dim)
        self.lfq = ProjectedLFQ(out_ch, quant_channels=quant_emb_dim)
        self.norm = nn.BatchNorm3d(out_ch)
        self.act = nn.GELU()

    def forward(self, x, skip, text_embedding):
        x_up = self.upsample(x)
        x_fused = self.act(self.norm(self.conv(torch.cat([x_up, skip], dim=1))))
        h_attn = x_fused + self.text_cross_attn(x_fused, text_embedding)
        q, indices, entropy_loss = self.lfq(h_attn)
        return h_attn, q, indices, entropy_loss

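# Shape sketch for the first pyramid stage under the default config
# (illustrative only):
#
#   x    : (B, 512, 1, 4, 4)  --upsample-->       (B, 256, 2, 8, 8)
#   skip : (B, 256, 2, 8, 8)  --concat + conv-->  (B, 256, 2, 8, 8)
#   lfq  : q and indices have quant_emb_dim (=16) channels at that resolution
#
# The continuous h_attn is what flows on to the next stage / decoder, while q
# and indices feed the quantization-related losses.
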
class VideoVAEDecoderBlock(nn.Module):
    """Standard VAE decoder block for upsampling."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = nn.ConvTranspose3d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
        self.conv = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1)
        self.norm = nn.BatchNorm3d(out_ch)
        self.act = nn.GELU()

    def forward(self, x):
        h = self.act(self.norm(self.upsample(x)))
        return self.act(self.norm(self.conv(h)))

# ==============================================================================
# 4. PRIMARY VideoVAE MODEL
# ==============================================================================

class VideoVAE(nn.Module):
    """
    A modular, text-conditioned Video VAE with a Pyramidal LFQ structure
    and multiple perception-based losses for high-quality synthesis.
    """
    def __init__(self, cfg: VideoVAEConfig, device="cuda"):
        super().__init__()
        self.cfg = cfg
        self.device = device

        # --- Sub-models (Text, Perception) ---
        self.text_encoder = QwenVLTextEncoder(device=device)
        text_embed_dim = self.text_encoder.model.config.hidden_size
        # nn.Module defaults to training mode at construction, so DINOv2 is loaded
        # here; it is only used by calculate_losses() during training.
        if self.training:
            self.dino_extractor = DINOv2Extractor(device=device)

        # --- VAE Encoder ---
        self.enc_blocks = nn.ModuleList()
        chs = [cfg.base_ch * (2**i) for i in range(cfg.num_blocks)]
        current_ch = cfg.in_channels
        for ch in chs:
            self.enc_blocks.append(VideoVAEEncoderBlock(current_ch, ch))
            current_ch = ch

        # --- Pyramidal LFQ Upsampler ---
        rev_channels = list(reversed(chs))
        self.pyramid_blocks = nn.ModuleList()
        for i in range(2):  # 2 stages for 4x total upscaling
            self.pyramid_blocks.append(
                PyramidalLFQBlock(rev_channels[i], rev_channels[i + 1], rev_channels[i + 1], text_embed_dim, cfg.quant_emb_dim)
            )

        # --- VAE Decoder ---
        self.dec_blocks = nn.ModuleList()
        decoder_channels = [chs[1], chs[0]]
        for i in range(len(decoder_channels)):
            in_ch = decoder_channels[i]
            out_ch = decoder_channels[i + 1] if i + 1 < len(decoder_channels) else cfg.base_ch
            self.dec_blocks.append(VideoVAEDecoderBlock(in_ch, out_ch))
        self.out_conv = nn.Conv3d(cfg.base_ch, cfg.in_channels, 1)

        # --- Loss-specific Modules ---
        codebook_size = 2**cfg.quant_emb_dim
        self.quant_embedding = nn.Embedding(codebook_size, text_embed_dim)
        self.to_quant_logits = nn.Linear(text_embed_dim, codebook_size)
        # Each pyramid stage contributes one globally pooled LFQ vector with
        # quant_emb_dim channels (see calculate_losses), so the concatenated
        # pooled feature has 2 * quant_emb_dim dimensions.
        quant_pooled_dim = 2 * cfg.quant_emb_dim
        self.quant_proj = nn.Linear(quant_pooled_dim, cfg.alignment_dim)
        self.text_proj_for_quant = nn.Linear(text_embed_dim, cfg.alignment_dim)

    def forward(self, x: torch.Tensor, text_prompts: List[str]) -> Dict[str, torch.Tensor]:
        """
        Core inference path. Encodes, quantizes via pyramid, and decodes.
        Returns all intermediate products needed for loss calculation.
        """
        text_embedding = self.text_encoder(text_prompts)

        encoder_features = []
        h = x
        for block in self.enc_blocks:
            h = block(h)
            encoder_features.append(h)

        rev_features = list(reversed(encoder_features))
        h = rev_features[0]
        pyramid_outputs = {'q': [], 'indices': [], 'entropies': []}
        for i, block in enumerate(self.pyramid_blocks):
            h, q, indices, entropy = block(h, rev_features[i + 1], text_embedding)
            pyramid_outputs['q'].append(q)
            pyramid_outputs['indices'].append(indices)
            pyramid_outputs['entropies'].append(entropy)

        dec_in = h
        for block in self.dec_blocks:
            dec_in = block(dec_in)
        reconstruction = torch.tanh(self.out_conv(dec_in))

        return {
            "reconstruction": reconstruction,
            "text_embedding": text_embedding,
            "pyramid_outputs": pyramid_outputs
        }

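    # Shape walk-through of forward() for the default config and a
    # (B, 3, 16, 64, 64) input (illustrative only):
    #
    #   encoder  : 4 blocks, /2 each       -> (B, 512, 1, 4, 4)
    #   pyramid  : 2 LFQ stages, x2 each   -> (B, 128, 4, 16, 16)
    #   decoder  : 2 blocks, x2 each       -> (B, 64, 16, 64, 64)
    #   out_conv : 1x1x1 conv + tanh       -> (B, 3, 16, 64, 64)
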
    def calculate_losses(self, original_video: torch.Tensor, forward_outputs: Dict) -> Dict:
        """
        Calculates all training-specific losses. This method should only be
        called during the training loop.
        """
        if not self.training:
            raise RuntimeError("calculate_losses() should only be called in training mode.")

        # Unpack forward pass results
        recon = forward_outputs["reconstruction"]
        text_emb = forward_outputs["text_embedding"]
        pyramid_out = forward_outputs["pyramid_outputs"]
        all_q, all_indices, all_entropies = pyramid_out['q'], pyramid_out['indices'], pyramid_out['entropies']

        # 1. Reconstruction Loss
        recon_loss = F.mse_loss(recon, original_video)

        # 2. Entropy Loss (codebook-usage regularizer from each LFQ stage)
        entropy_loss = sum(all_entropies)

        # 3. P(Q|text) Likelihood Loss
        B = text_emb.size(0)
        # Flatten each stage's bit tensor to (B, quant_emb_dim, N) and pack the
        # bits into integer token ids (LSB-first powers of two).
        seqs = [idx.view(B, self.cfg.quant_emb_dim, -1) for idx in all_indices]
        full_seq_bits = torch.cat(seqs, dim=2).permute(0, 2, 1)
        powers_of_2 = (2**torch.arange(self.cfg.quant_emb_dim, device=self.device)).float()
        quant_token_ids = (full_seq_bits * powers_of_2).sum(dim=2).long()
        # Match the frozen LM's dtype so text and code embeddings can be concatenated.
        quant_embeds = self.quant_embedding(quant_token_ids).to(text_emb.dtype)
        combined_embeds = torch.cat([text_emb, quant_embeds], dim=1)
        # The language model stays frozen under no_grad; only to_quant_logits is
        # trained by this loss.
        with torch.no_grad():
            qwen_outputs = self.text_encoder.model(inputs_embeds=combined_embeds, output_hidden_states=True)
        last_hidden = qwen_outputs.hidden_states[-1][:, text_emb.shape[1] - 1:-1, :]
        pred_logits = self.to_quant_logits(last_hidden.float())
        likelihood_loss = F.cross_entropy(pred_logits.reshape(-1, pred_logits.size(-1)), quant_token_ids.reshape(-1))

        # 4. Quantized Vector-Text Alignment Loss
        q_pooled = [F.adaptive_avg_pool3d(q, 1).view(B, -1) for q in all_q]
        q_pooled_cat = torch.cat(q_pooled, dim=1)
        text_pooled = text_emb.mean(dim=1)
        q_aligned = self.quant_proj(q_pooled_cat)
        text_aligned = self.text_proj_for_quant(text_pooled.float())
        quant_align_loss = F.cosine_embedding_loss(q_aligned, text_aligned, torch.ones(B, device=self.device))

        # 5. DINOv2 Perceptual Loss (KL Divergence between frame-feature distributions)
        with torch.no_grad():  # target features from the original clip need no graph
            orig_dino_feats = self.dino_extractor(original_video)
        recon_dino_feats = self.dino_extractor(recon)
        p = F.softmax(orig_dino_feats, dim=-1)
        q = F.log_softmax(recon_dino_feats, dim=-1)
        dino_loss = F.kl_div(q, p, reduction='batchmean')

        # --- Final Weighted Sum ---
        total_loss = (recon_loss + entropy_loss +
                      self.cfg.likelihood_loss_weight * likelihood_loss +
                      self.cfg.quant_align_loss_weight * quant_align_loss +
                      self.cfg.dino_loss_weight * dino_loss)

        return {
            "total_loss": total_loss, "reconstruction_loss": recon_loss,
            "entropy_loss": entropy_loss, "likelihood_loss": likelihood_loss,
            "quant_alignment_loss": quant_align_loss, "dino_perceptual_loss": dino_loss
        }

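# Worked example for the bit-packing used in the likelihood loss (toy numbers,
# illustrative only): with quant_emb_dim = 4, a per-position bit vector
# [1, 0, 1, 1] is packed LSB-first into a single codebook token id
#
#   1*2**0 + 0*2**1 + 1*2**2 + 1*2**3 = 13
#
# so the pyramid's binary codes become ordinary ids in a 2**quant_emb_dim
# vocabulary that the frozen language model can score.
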
# ==============================================================================
# 5. EXAMPLE USAGE
# ==============================================================================
if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("WARNING: Running on CPU. This will be extremely slow.")

    try:
        config = VideoVAEConfig(quant_emb_dim=16)  # Set LFQ size to 16
        model = VideoVAE(config, device=device).to(device)

        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print("-" * 40)
        print(f"Trainable model parameters: {trainable_params:,}")
        print("(This should NOT include frozen DINOv2 or Qwen-VL models)")
        print("-" * 40)

        # --- SIMULATED TRAINING STEP ---
        print("\n--- 1. Simulating Training Step ---")
        model.train()  # Set model to training mode
        batch_size = 2
        video_input = torch.randn(batch_size, 3, 16, 64, 64).to(device)
        prompts = ["A stunning sunrise over a calm ocean.", "A busy city street at night with neon lights."]

        # In a real training loop, this would be inside the loop
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
        optimizer.zero_grad()

        forward_outputs = model(video_input, text_prompts=prompts)
        losses = model.calculate_losses(video_input, forward_outputs)

        # Backpropagation
        losses["total_loss"].backward()
        optimizer.step()

        print("Training step successful. Losses calculated:")
        for name, value in losses.items():
            print(f"  - {name:<25}: {value.item():.4f}")

        # --- SIMULATED INFERENCE STEP ---
        print("\n--- 2. Simulating Inference Step ---")
        model.eval()  # Set model to evaluation mode
        with torch.no_grad():
            # Notice we only call the forward pass and don't need the loss function
            inference_outputs = model(video_input, text_prompts=prompts)
            reconstructed_video = inference_outputs["reconstruction"]

        print("Inference step successful.")
        print("Input Video Shape: ", video_input.shape)
        print("Reconstructed Video Shape: ", reconstructed_video.shape)

    except Exception as e:
        print("\n--- ❌ An Error Occurred ---")
        print(f"Error: {e}")
        if "out of memory" in str(e).lower():
            print("\n💡 Suggestion: CUDA Out-of-Memory. Try reducing `base_ch`, `num_blocks`, or input resolution.")
