
Commit ffb8b0a

Fix latent preparation
1 parent 1edf638 commit ffb8b0a

3 files changed: +96 −19 lines changed


fastvideo/pipelines/stages/input_validation.py

Lines changed: 5 additions & 2 deletions

@@ -36,9 +36,12 @@ def _generate_seeds(self, batch: ForwardBatch,
         assert seed is not None
         seeds = [seed + i for i in range(num_videos_per_prompt)]
         batch.seeds = seeds
-        # Peiyuan: using GPU seed will cause A100 and H100 to generate different results...
+        # Use device-specific generators to match diffusers behavior
+        # diffusers uses torch.Generator(device=device).manual_seed()
+        from fastvideo.distributed import get_local_torch_device
+        device = get_local_torch_device()
         batch.generator = [
-            torch.Generator("cpu").manual_seed(seed) for seed in seeds
+            torch.Generator(device=device).manual_seed(seed) for seed in seeds
        ]

    def forward(
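
Note: this change swaps CPU-bound generators for device-bound ones so that
seeded noise matches what diffusers samples directly on the device (diffusers'
randn_tensor draws on the generator's device, so a CPU generator yields a
different random stream). A minimal standalone sketch of the pattern, with
device selection simplified here (the committed code gets it from
fastvideo.distributed.get_local_torch_device()):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed, num_videos_per_prompt = 42, 2

    seeds = [seed + i for i in range(num_videos_per_prompt)]
    generators = [torch.Generator(device=device).manual_seed(s) for s in seeds]

    # Noise drawn with a device-bound generator reproduces device-side sampling.
    noise = torch.randn((1, 16, 6, 88, 160), generator=generators[0], device=device)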

fastvideo/pipelines/stages/latent_preparation.py

Lines changed: 90 additions & 16 deletions

@@ -176,16 +176,10 @@ def forward(
         # 704/88 = 8, 1280/160 = 8, so spatial_scale = 8
         vae_scale_factor_spatial = 8

-        # For temporal: Need 93 frames -> 7 latent frames
-        # Using formula: (num_frames - 1) // temporal_scale + 1 = 7
-        # So: (93 - 1) // temporal_scale + 1 = 7
-        # 92 // temporal_scale + 1 = 7
-        # 92 // temporal_scale = 6
-        # temporal_scale = 92 // 6 = 15.33 -> try 13
-        # Check: (93-1)//13 + 1 = 92//13 + 1 = 7 + 1 = 8 (too high)
-        # Try 12: (93-1)//12 + 1 = 92//12 + 1 = 7 + 1 = 8 (too high)
-        # Try 15: (93-1)//15 + 1 = 92//15 + 1 = 6 + 1 = 7 ✓
-        vae_scale_factor_temporal = 15
+        # For temporal: Use the same scale factor as diffusers Cosmos pipeline
+        # Diffusers uses vae_scale_factor_temporal = 4 as default
+        # For 21 frames: (21-1)//4+1 = 20//4+1 = 5+1 = 6 latent frames (matches diffusers)
+        vae_scale_factor_temporal = 4

         # Also check if height needs different scaling
         # 704 -> 88: 704/8 = 88 ✓
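
Note: the deleted comment block derived the temporal factor by trial and error
for a 93-frame target; the commit pins it to diffusers' Cosmos default of 4
instead. The latent-frame count follows directly from the formula in the
comments; a quick arithmetic check (no pipeline imports needed):

    def num_latent_frames(num_frames: int, vae_scale_factor_temporal: int = 4) -> int:
        return (num_frames - 1) // vae_scale_factor_temporal + 1

    assert num_latent_frames(21) == 6    # (21 - 1) // 4 + 1, matches diffusers
    assert num_latent_frames(93) == 24   # the old factor of 15 gave 7 here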
@@ -224,60 +218,135 @@ def forward(
         # Process input video if provided (video-to-world generation)
         # Check multiple possible sources for video input
         video = None
+        logger.info(f"CosmosLatentPreparationStage - Checking for video inputs:")
+        logger.info(f"  batch.video: {getattr(batch, 'video', 'Not found')}")
+        logger.info(f"  batch.pil_image: {getattr(batch, 'pil_image', 'Not found')}")
+        logger.info(f"  batch.preprocessed_image: {getattr(batch, 'preprocessed_image', 'Not found')}")
+
         if hasattr(batch, 'video') and batch.video is not None:
             video = batch.video
+            logger.info("CosmosLatentPreparationStage - Using batch.video")
         elif hasattr(batch, 'pil_image') and batch.pil_image is not None:
+            logger.info(f"CosmosLatentPreparationStage - Found pil_image of type: {type(batch.pil_image)}")
             # Convert single image to video format if needed
             if isinstance(batch.pil_image, torch.Tensor):
+                logger.info(f"CosmosLatentPreparationStage - pil_image tensor shape: {batch.pil_image.shape}")
                 if batch.pil_image.dim() == 4:  # [B, C, H, W] -> [B, C, T, H, W]
                     video = batch.pil_image.unsqueeze(2)
+                    logger.info(f"CosmosLatentPreparationStage - Converted 4D to 5D tensor: {video.shape}")
                 elif batch.pil_image.dim() == 5:  # Already [B, C, T, H, W]
                     video = batch.pil_image
+                    logger.info(f"CosmosLatentPreparationStage - Using 5D tensor as-is: {video.shape}")
+            else:
+                logger.info("CosmosLatentPreparationStage - pil_image is not a tensor, needs preprocessing")
+                # Following diffusers approach for image-to-video preprocessing
+                # Convert PIL image to tensor and add temporal dimension
+                import torchvision.transforms as transforms
+
+                # Create transform pipeline similar to diffusers VideoProcessor
+                transform = transforms.Compose([
+                    transforms.Resize((height, width), antialias=True),
+                    transforms.ToTensor(),
+                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
+                ])
+
+                # Apply transform to get [C, H, W] tensor
+                image_tensor = transform(batch.pil_image)
+                logger.info(f"CosmosLatentPreparationStage - Transformed PIL to tensor: {image_tensor.shape}")
+
+                # Add batch dimension: [C, H, W] -> [B, C, H, W]
+                image_tensor = image_tensor.unsqueeze(0)
+
+                # Add time dimension like diffusers: [B, C, H, W] -> [B, C, T, H, W]
+                video = image_tensor.unsqueeze(2)  # Add time dim at position 2
+                logger.info(f"CosmosLatentPreparationStage - Added batch and time dims: {video.shape}")
+
+                # Move to correct device and ensure compatible dtype for VAE
+                # Use VAE's parameter dtype to avoid dtype mismatches
+                if self.vae is not None:
+                    vae_dtype = next(self.vae.parameters()).dtype
+                else:
+                    vae_dtype = dtype
+                video = video.to(device=device, dtype=vae_dtype)
+                logger.info(f"CosmosLatentPreparationStage - Video tensor device: {video.device}, dtype: {video.dtype}")
         elif hasattr(batch, 'preprocessed_image') and batch.preprocessed_image is not None:
+            logger.info(f"CosmosLatentPreparationStage - Found preprocessed_image of type: {type(batch.preprocessed_image)}")
             # Convert preprocessed image to video format
             if isinstance(batch.preprocessed_image, torch.Tensor):
+                logger.info(f"CosmosLatentPreparationStage - preprocessed_image tensor shape: {batch.preprocessed_image.shape}")
                 if batch.preprocessed_image.dim() == 4:  # [B, C, H, W] -> [B, C, T, H, W]
                     video = batch.preprocessed_image.unsqueeze(2)
+                    logger.info(f"CosmosLatentPreparationStage - Converted 4D to 5D tensor: {video.shape}")
                 elif batch.preprocessed_image.dim() == 5:  # Already [B, C, T, H, W]
                     video = batch.preprocessed_image
+                    logger.info(f"CosmosLatentPreparationStage - Using 5D tensor as-is: {video.shape}")
+        else:
+            logger.info("CosmosLatentPreparationStage - No video input sources found")

         if video is not None:
-            video = batch.video
             num_cond_frames = video.size(2)

+            logger.info(f"CosmosLatentPreparationStage - Number of conditioning frames: {num_cond_frames}")
+
             if num_cond_frames >= num_frames:
                 # Take the last `num_frames` frames for conditioning
                 num_cond_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
                 video = video[:, :, -num_frames:]
+                logger.info(f"CosmosLatentPreparationStage - Using last {num_frames} frames from {num_cond_frames} conditioning frames")
             else:
                 num_cond_latent_frames = (num_cond_frames - 1) // vae_scale_factor_temporal + 1
                 num_padding_frames = num_frames - num_cond_frames
                 last_frame = video[:, :, -1:]
                 padding = last_frame.repeat(1, 1, num_padding_frames, 1, 1)
                 video = torch.cat([video, padding], dim=2)
+                logger.info(f"CosmosLatentPreparationStage - Padding {num_cond_frames} conditioning frames with {num_padding_frames} repeated frames")

             # Encode video through VAE like diffusers does
             if self.vae is not None:
+                # Move VAE to correct device before encoding
+                self.vae = self.vae.to(device)
                 if isinstance(generator, list):
                     init_latents = []
                     for i in range(batch_size):
                         vae_output = self.vae.encode(video[i].unsqueeze(0))
+                        logger.info(f"CosmosLatentPreparationStage - VAE output type: {type(vae_output)}, attributes: {dir(vae_output)}")
+
+                        # Handle different VAE output types
                         if hasattr(vae_output, 'latent_dist'):
                             init_latents.append(vae_output.latent_dist.sample(generator[i] if i < len(generator) else None))
                         elif hasattr(vae_output, 'latents'):
                             init_latents.append(vae_output.latents)
+                        elif hasattr(vae_output, 'sample'):
+                            init_latents.append(vae_output.sample(generator[i] if i < len(generator) else None))
+                        elif isinstance(vae_output, torch.Tensor):
+                            # Direct tensor output
+                            init_latents.append(vae_output)
                         else:
-                            raise AttributeError("Could not access latents of provided encoder_output")
+                            # Try to get the first attribute that looks like latents
+                            attrs = [attr for attr in dir(vae_output) if not attr.startswith('_')]
+                            logger.info(f"CosmosLatentPreparationStage - Available attributes: {attrs}")
+                            raise AttributeError(f"Could not access latents from VAE output. Available attributes: {attrs}")
                 else:
                     init_latents_list = []
                     for vid in video:
                         vae_output = self.vae.encode(vid.unsqueeze(0))
+                        logger.info(f"CosmosLatentPreparationStage - VAE output type: {type(vae_output)}, attributes: {dir(vae_output)}")
+
+                        # Handle different VAE output types
                         if hasattr(vae_output, 'latent_dist'):
                             init_latents_list.append(vae_output.latent_dist.sample(generator))
                         elif hasattr(vae_output, 'latents'):
                             init_latents_list.append(vae_output.latents)
+                        elif hasattr(vae_output, 'sample'):
+                            init_latents_list.append(vae_output.sample(generator))
+                        elif isinstance(vae_output, torch.Tensor):
+                            # Direct tensor output
+                            init_latents_list.append(vae_output)
                         else:
-                            raise AttributeError("Could not access latents of provided encoder_output")
+                            # Try to get the first attribute that looks like latents
+                            attrs = [attr for attr in dir(vae_output) if not attr.startswith('_')]
+                            logger.info(f"CosmosLatentPreparationStage - Available attributes: {attrs}")
+                            raise AttributeError(f"Could not access latents from VAE output. Available attributes: {attrs}")
                     init_latents = init_latents_list

                 init_latents = torch.cat(init_latents, dim=0).to(dtype)
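
Note: both generator branches above repeat the same fallback chain for pulling
latents out of the VAE output. Factored into a standalone helper it would look
like the sketch below (hypothetical retrieve_latents, shown for clarity only;
the commit keeps the chain inline):

    import torch

    def retrieve_latents(vae_output, generator=None):
        # Mirrors the inline fallback chain in the diff above.
        if hasattr(vae_output, 'latent_dist'):    # e.g. diffusers AutoencoderKLOutput
            return vae_output.latent_dist.sample(generator)
        if hasattr(vae_output, 'latents'):        # outputs exposing latents directly
            return vae_output.latents
        if hasattr(vae_output, 'sample') and callable(vae_output.sample):
            return vae_output.sample(generator)
        if isinstance(vae_output, torch.Tensor):  # raw tensor output
            return vae_output
        attrs = [a for a in dir(vae_output) if not a.startswith('_')]
        raise AttributeError(
            f"Could not access latents from VAE output. Available attributes: {attrs}")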
@@ -289,8 +358,12 @@ def forward(
                 init_latents = (init_latents - latents_mean) / latents_std * self.scheduler.sigma_data

                 conditioning_latents = init_latents
+
+                # Offload VAE to CPU after encoding to save memory
+                self.vae.to("cpu")
         else:
             num_cond_latent_frames = 0
+            logger.info("CosmosLatentPreparationStage - No conditioning frames detected (no video input)")

         # Generate or use provided latents
         if latents is None:
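
Note: the offload added above keeps the VAE on the GPU only for the encode
step. A runnable sketch of the same move-encode-offload pattern with a
stand-in module (the committed code moves self.vae, not a Conv3d):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    vae = nn.Conv3d(3, 16, kernel_size=1)   # stand-in for the Cosmos VAE encoder
    video = torch.randn(1, 3, 21, 64, 64)   # [B, C, T, H, W]

    vae = vae.to(device)                    # weights on GPU only while encoding
    latents = vae(video.to(device))
    vae.to("cpu")                           # release GPU memory once latents exist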
@@ -301,7 +374,7 @@ def forward(
         else:
             latents = latents.to(device=device, dtype=dtype)

-        # Scale latents by sigma_max (Cosmos-specific)
+        # Scale latents by sigma_max (Cosmos-specific) - exactly like diffusers
         latents = latents * self.scheduler.sigma_max

         # Create conditioning masks (for video-to-world generation)
@@ -340,8 +413,9 @@ def forward(

         # Final verification that shape is correct
         logger.info(f"CosmosLatentPreparationStage - FINAL latents shape: {latents.shape}")
-        # Compare with Diffusers but adjust for actual input dimensions
-        diffusers_expected = torch.Size([1, 16, 7, 88, 160])
+        # Compare with Diffusers expected shape for our dimensions
+        # For 21 frames with temporal_scale=4: (21-1)//4+1 = 6 latent frames
+        diffusers_expected = torch.Size([1, 16, 6, 88, 160])
         if latents.shape != diffusers_expected:
             logger.warning(f"CosmosLatentPreparationStage - Shape differs from Diffusers: Expected {diffusers_expected}, got {latents.shape}")
             logger.info(f"CosmosLatentPreparationStage - This may be due to different input dimensions (height={height}, width={width})")

test_fastvideo_pipeline.py

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ def generate_video():
         prompt=prompt,
         negative_prompt=negative_prompt,
         num_frames=21,
-        input_path=input_image_path,
+        image_path=input_image_path,
         num_inference_steps=35,
         guidance_scale=7.0,
         seed=42,
