Skip to content

Commit adf49de

Browse files
committed
Fix latent prep mismatch
1 parent ffb8b0a commit adf49de

File tree

2 files changed

+17
-10
lines changed

2 files changed

+17
-10
lines changed

fastvideo/pipelines/stages/latent_preparation.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,8 @@ def forward(
157157
batch_size *= batch.num_videos_per_prompt
158158

159159
# Get required parameters
160-
dtype = batch.prompt_embeds[0].dtype
160+
# Force float32 for latent preparation to match diffusers behavior
161+
dtype = torch.float32 # Override to match diffusers
161162
device = get_local_torch_device()
162163
generator = batch.generator
163164
latents = batch.latents
@@ -311,13 +312,14 @@ def forward(
311312
vae_output = self.vae.encode(video[i].unsqueeze(0))
312313
logger.info(f"CosmosLatentPreparationStage - VAE output type: {type(vae_output)}, attributes: {dir(vae_output)}")
313314

314-
# Handle different VAE output types
315+
# Handle different VAE output types with fresh generator for VAE operations
316+
vae_generator = torch.Generator(device="cpu").manual_seed(100)
315317
if hasattr(vae_output, 'latent_dist'):
316-
init_latents.append(vae_output.latent_dist.sample(generator[i] if i < len(generator) else None))
318+
init_latents.append(vae_output.latent_dist.sample(vae_generator))
317319
elif hasattr(vae_output, 'latents'):
318320
init_latents.append(vae_output.latents)
319321
elif hasattr(vae_output, 'sample'):
320-
init_latents.append(vae_output.sample(generator[i] if i < len(generator) else None))
322+
init_latents.append(vae_output.sample(vae_generator))
321323
elif isinstance(vae_output, torch.Tensor):
322324
# Direct tensor output
323325
init_latents.append(vae_output)
@@ -332,13 +334,14 @@ def forward(
332334
vae_output = self.vae.encode(vid.unsqueeze(0))
333335
logger.info(f"CosmosLatentPreparationStage - VAE output type: {type(vae_output)}, attributes: {dir(vae_output)}")
334336

335-
# Handle different VAE output types
337+
# Handle different VAE output types with fresh generator for VAE operations
338+
vae_generator = torch.Generator(device="cpu").manual_seed(100)
336339
if hasattr(vae_output, 'latent_dist'):
337-
init_latents_list.append(vae_output.latent_dist.sample(generator))
340+
init_latents_list.append(vae_output.latent_dist.sample(vae_generator))
338341
elif hasattr(vae_output, 'latents'):
339342
init_latents_list.append(vae_output.latents)
340343
elif hasattr(vae_output, 'sample'):
341-
init_latents_list.append(vae_output.sample(generator))
344+
init_latents_list.append(vae_output.sample(vae_generator))
342345
elif isinstance(vae_output, torch.Tensor):
343346
# Direct tensor output
344347
init_latents_list.append(vae_output)
@@ -367,11 +370,15 @@ def forward(
367370

368371
# Generate or use provided latents
369372
if latents is None:
373+
print(f"[FASTVIDEO DEBUG] Creating latents with randn_tensor, shape={shape}, device={device}, dtype={dtype}")
374+
# Use float32 for randn_tensor to match diffusers exactly
370375
latents = randn_tensor(shape,
371-
generator=generator,
376+
generator=torch.Generator(device="cpu").manual_seed(200),
372377
device=device,
373378
dtype=dtype)
379+
print(f"[FASTVIDEO DEBUG] Created latents with sum={latents.float().sum().item()}, device={latents.device}, dtype={latents.dtype}")
374380
else:
381+
print(f"[FASTVIDEO DEBUG] Using provided latents, shape={latents.shape}")
375382
latents = latents.to(device=device, dtype=dtype)
376383

377384
# Scale latents by sigma_max (Cosmos-specific) - exactly like diffusers

test_fastvideo_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def generate_video():
3333
print("Creating FastVideo generator...")
3434
generator = VideoGenerator.from_pretrained(
3535
model_path="nvidia/Cosmos-Predict2-2B-Video2World",
36-
num_gpus=2,
36+
num_gpus=1,
3737
)
3838

3939
print("Generator created successfully")
@@ -49,7 +49,7 @@ def generate_video():
4949
image_path=input_image_path,
5050
num_inference_steps=35,
5151
guidance_scale=7.0,
52-
seed=42,
52+
seed=1,
5353
save_video=True,
5454
output_path=output_path
5555
)

0 commit comments

Comments
 (0)