@@ -157,7 +157,8 @@ def forward(
157157 batch_size *= batch .num_videos_per_prompt
158158
159159 # Get required parameters
160- dtype = batch .prompt_embeds [0 ].dtype
160+ # Force float32 for latent preparation to match diffusers behavior
161+ dtype = torch .float32 # Override to match diffusers
161162 device = get_local_torch_device ()
162163 generator = batch .generator
163164 latents = batch .latents
@@ -311,13 +312,14 @@ def forward(
311312 vae_output = self .vae .encode (video [i ].unsqueeze (0 ))
312313 logger .info (f"CosmosLatentPreparationStage - VAE output type: { type (vae_output )} , attributes: { dir (vae_output )} " )
313314
314- # Handle different VAE output types
315+ # Handle different VAE output types with fresh generator for VAE operations
316+ vae_generator = torch .Generator (device = "cpu" ).manual_seed (100 )
315317 if hasattr (vae_output , 'latent_dist' ):
316- init_latents .append (vae_output .latent_dist .sample (generator [ i ] if i < len ( generator ) else None ))
318+ init_latents .append (vae_output .latent_dist .sample (vae_generator ))
317319 elif hasattr (vae_output , 'latents' ):
318320 init_latents .append (vae_output .latents )
319321 elif hasattr (vae_output , 'sample' ):
320- init_latents .append (vae_output .sample (generator [ i ] if i < len ( generator ) else None ))
322+ init_latents .append (vae_output .sample (vae_generator ))
321323 elif isinstance (vae_output , torch .Tensor ):
322324 # Direct tensor output
323325 init_latents .append (vae_output )
@@ -332,13 +334,14 @@ def forward(
332334 vae_output = self .vae .encode (vid .unsqueeze (0 ))
333335 logger .info (f"CosmosLatentPreparationStage - VAE output type: { type (vae_output )} , attributes: { dir (vae_output )} " )
334336
335- # Handle different VAE output types
337+ # Handle different VAE output types with fresh generator for VAE operations
338+ vae_generator = torch .Generator (device = "cpu" ).manual_seed (100 )
336339 if hasattr (vae_output , 'latent_dist' ):
337- init_latents_list .append (vae_output .latent_dist .sample (generator ))
340+ init_latents_list .append (vae_output .latent_dist .sample (vae_generator ))
338341 elif hasattr (vae_output , 'latents' ):
339342 init_latents_list .append (vae_output .latents )
340343 elif hasattr (vae_output , 'sample' ):
341- init_latents_list .append (vae_output .sample (generator ))
344+ init_latents_list .append (vae_output .sample (vae_generator ))
342345 elif isinstance (vae_output , torch .Tensor ):
343346 # Direct tensor output
344347 init_latents_list .append (vae_output )
@@ -367,11 +370,15 @@ def forward(
367370
368371 # Generate or use provided latents
369372 if latents is None :
373+ print (f"[FASTVIDEO DEBUG] Creating latents with randn_tensor, shape={ shape } , device={ device } , dtype={ dtype } " )
374+ # Use float32 for randn_tensor to match diffusers exactly
370375 latents = randn_tensor (shape ,
371- generator = generator ,
376+ generator = torch . Generator ( device = "cpu" ). manual_seed ( 200 ) ,
372377 device = device ,
373378 dtype = dtype )
379+ print (f"[FASTVIDEO DEBUG] Created latents with sum={ latents .float ().sum ().item ()} , device={ latents .device } , dtype={ latents .dtype } " )
374380 else :
381+ print (f"[FASTVIDEO DEBUG] Using provided latents, shape={ latents .shape } " )
375382 latents = latents .to (device = device , dtype = dtype )
376383
377384 # Scale latents by sigma_max (Cosmos-specific) - exactly like diffusers
0 commit comments