Commit e45dab4

Fix vae decoding
1 parent a10d0a6 commit e45dab4

File tree

1 file changed (+42 −4 lines)

fastvideo/pipelines/stages/decoding.py

Lines changed: 42 additions & 4 deletions
@@ -82,6 +82,10 @@ def forward(
         if latents is None:
             raise ValueError("Latents must be provided")
 
+        print(f"[FASTVIDEO VAE DEBUG] Before scaling/shifting - latents sum: {latents.float().sum().item():.6f}, shape: {latents.shape}, dtype: {latents.dtype}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO VAE DEBUG] Before scaling/shifting - latents sum: {latents.float().sum().item():.6f}, shape: {latents.shape}, dtype: {latents.dtype}\n")
+
         # Skip decoding if output type is latent
         if fastvideo_args.output_type == "latent":
             image = latents
@@ -92,18 +96,40 @@ def forward(
         vae_autocast_enabled = (vae_dtype != torch.float32
                                 ) and not fastvideo_args.disable_autocast
 
-        # TEMPORARY: Handle diffusers VAE compatibility
-        if hasattr(self.vae, 'scaling_factor'):
+        # Apply latents normalization for Cosmos VAE (matching diffusers implementation)
+        # Source: /workspace/diffusers/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py:1000-1010
+        if hasattr(self.vae, 'config') and hasattr(self.vae.config, 'latents_mean') and hasattr(self.vae.config, 'latents_std'):
+            # Get scheduler for sigma_data
+            pipeline = self.pipeline() if self.pipeline else None
+            sigma_data = 1.0  # default
+            if pipeline and hasattr(pipeline, 'modules') and 'scheduler' in pipeline.modules:
+                scheduler = pipeline.modules['scheduler']
+                if hasattr(scheduler, 'config') and hasattr(scheduler.config, 'sigma_data'):
+                    sigma_data = scheduler.config.sigma_data
+
+            latents_mean = (
+                torch.tensor(self.vae.config.latents_mean)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = (
+                torch.tensor(self.vae.config.latents_std)
+                .view(1, self.vae.config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents = latents * latents_std / sigma_data + latents_mean
+        # Fallback to scaling_factor for other VAE types
+        elif hasattr(self.vae, 'scaling_factor'):
             if isinstance(self.vae.scaling_factor, torch.Tensor):
                 latents = latents / self.vae.scaling_factor.to(
                     latents.device, latents.dtype)
             else:
                 latents = latents / self.vae.scaling_factor
         elif hasattr(self.vae, 'config') and hasattr(self.vae.config, 'scaling_factor'):
-            # Fallback to config scaling factor for diffusers VAE
+            # Fallback to config scaling factor for other diffusers VAEs
             latents = latents / self.vae.config.scaling_factor
 
-        # Apply shifting if needed
+        # Apply shifting if needed (for other VAE types)
         if (hasattr(self.vae, "shift_factor")
                 and self.vae.shift_factor is not None):
             if isinstance(self.vae.shift_factor, torch.Tensor):
@@ -112,6 +138,10 @@ def forward(
             else:
                 latents += self.vae.shift_factor
 
+        print(f"[FASTVIDEO VAE DEBUG] After scaling/shifting - latents sum: {latents.float().sum().item():.6f}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO VAE DEBUG] After scaling/shifting - latents sum: {latents.float().sum().item():.6f}\n")
+
         # Decode latents
         with torch.autocast(device_type="cuda",
                             dtype=vae_dtype,
@@ -132,9 +162,17 @@ def forward(
             # FastVideo VAE returns tensor directly
             image = decode_output
 
+        print(f"[FASTVIDEO VAE DEBUG] After decode - image sum: {image.float().sum().item():.6f}, shape: {image.shape}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO VAE DEBUG] After decode - image sum: {image.float().sum().item():.6f}, shape: {image.shape}\n")
+
         # Normalize image to [0, 1] range
         image = (image / 2 + 0.5).clamp(0, 1)
 
+        print(f"[FASTVIDEO VAE DEBUG] After normalization - image sum: {image.float().sum().item():.6f}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO VAE DEBUG] After normalization - image sum: {image.float().sum().item():.6f}\n")
+
         # Convert to CPU float32 for compatibility
         image = image.cpu().float()
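
The substance of the fix is the Cosmos-style latent denormalization applied before decoding: the denoiser operates on latents that were standardized with per-channel statistics and scaled by sigma_data, so the decode stage must invert that (x = z * latents_std / sigma_data + latents_mean) rather than dividing by a single scaling_factor. The snippet below is a minimal standalone sketch of that inverse transform; the latents_mean, latents_std, sigma_data values and tensor shapes are illustrative placeholders, not values from an actual Cosmos checkpoint.

# Minimal sketch of the Cosmos-style latent denormalization used in this commit.
# The statistics below are placeholders; in the pipeline they come from
# self.vae.config.latents_mean / latents_std and scheduler.config.sigma_data.
import torch


def denormalize_latents(latents: torch.Tensor,
                        latents_mean: torch.Tensor,
                        latents_std: torch.Tensor,
                        sigma_data: float = 1.0) -> torch.Tensor:
    # Broadcast per-channel stats over (batch, channels, frames, height, width)
    # and invert the normalization: x = z * std / sigma_data + mean.
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    return latents * latents_std / sigma_data + latents_mean


# Illustrative usage with a 16-channel video latent (shapes and values are placeholders).
z_dim = 16
latents = torch.randn(1, z_dim, 8, 30, 52)    # (B, C, T, H, W)
latents_mean = torch.zeros(z_dim)             # placeholder per-channel mean
latents_std = torch.ones(z_dim)               # placeholder per-channel std
decoded_input = denormalize_latents(latents, latents_mean, latents_std, sigma_data=0.5)

Keeping the Cosmos branch ahead of the generic scaling_factor/shift_factor fallbacks is what prevents the old single-scalar path from being applied to a VAE whose config carries per-channel statistics.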
