@@ -176,16 +176,10 @@ def forward(
176176 # 704/88 = 8, 1280/160 = 8, so spatial_scale = 8
177177 vae_scale_factor_spatial = 8
178178
179- # For temporal: Need 93 frames -> 7 latent frames
180- # Using formula: (num_frames - 1) // temporal_scale + 1 = 7
181- # So: (93 - 1) // temporal_scale + 1 = 7
182- # 92 // temporal_scale + 1 = 7
183- # 92 // temporal_scale = 6
184- # temporal_scale = 92 // 6 = 15.33 -> try 13
185- # Check: (93-1)//13 + 1 = 92//13 + 1 = 7 + 1 = 8 (too high)
186- # Try 12: (93-1)//12 + 1 = 92//12 + 1 = 7 + 1 = 8 (too high)
187- # Try 15: (93-1)//15 + 1 = 92//15 + 1 = 6 + 1 = 7 ✓
188- vae_scale_factor_temporal = 15
179+ # For temporal: Use the same scale factor as diffusers Cosmos pipeline
180+ # Diffusers uses vae_scale_factor_temporal = 4 as default
181+ # For 21 frames: (21-1)//4+1 = 20//4+1 = 5+1 = 6 latent frames (matches diffusers)
182+ vae_scale_factor_temporal = 4
189183
190184 # Also check if height needs different scaling
191185 # 704 -> 88: 704/8 = 88 ✓
@@ -224,60 +218,135 @@ def forward(
224218 # Process input video if provided (video-to-world generation)
225219 # Check multiple possible sources for video input
226220 video = None
221+ logger .info (f"CosmosLatentPreparationStage - Checking for video inputs:" )
222+ logger .info (f" batch.video: { getattr (batch , 'video' , 'Not found' )} " )
223+ logger .info (f" batch.pil_image: { getattr (batch , 'pil_image' , 'Not found' )} " )
224+ logger .info (f" batch.preprocessed_image: { getattr (batch , 'preprocessed_image' , 'Not found' )} " )
225+
227226 if hasattr (batch , 'video' ) and batch .video is not None :
228227 video = batch .video
228+ logger .info ("CosmosLatentPreparationStage - Using batch.video" )
229229 elif hasattr (batch , 'pil_image' ) and batch .pil_image is not None :
230+ logger .info (f"CosmosLatentPreparationStage - Found pil_image of type: { type (batch .pil_image )} " )
230231 # Convert single image to video format if needed
231232 if isinstance (batch .pil_image , torch .Tensor ):
233+ logger .info (f"CosmosLatentPreparationStage - pil_image tensor shape: { batch .pil_image .shape } " )
232234 if batch .pil_image .dim () == 4 : # [B, C, H, W] -> [B, C, T, H, W]
233235 video = batch .pil_image .unsqueeze (2 )
236+ logger .info (f"CosmosLatentPreparationStage - Converted 4D to 5D tensor: { video .shape } " )
234237 elif batch .pil_image .dim () == 5 : # Already [B, C, T, H, W]
235238 video = batch .pil_image
239+ logger .info (f"CosmosLatentPreparationStage - Using 5D tensor as-is: { video .shape } " )
240+ else :
241+ logger .info ("CosmosLatentPreparationStage - pil_image is not a tensor, needs preprocessing" )
242+ # Following diffusers approach for image-to-video preprocessing
243+ # Convert PIL image to tensor and add temporal dimension
244+ import torchvision .transforms as transforms
245+
246+ # Create transform pipeline similar to diffusers VideoProcessor
247+ transform = transforms .Compose ([
248+ transforms .Resize ((height , width ), antialias = True ),
249+ transforms .ToTensor (),
250+ transforms .Normalize ([0.5 , 0.5 , 0.5 ], [0.5 , 0.5 , 0.5 ]) # Normalize to [-1, 1]
251+ ])
252+
253+ # Apply transform to get [C, H, W] tensor
254+ image_tensor = transform (batch .pil_image )
255+ logger .info (f"CosmosLatentPreparationStage - Transformed PIL to tensor: { image_tensor .shape } " )
256+
257+ # Add batch dimension: [C, H, W] -> [B, C, H, W]
258+ image_tensor = image_tensor .unsqueeze (0 )
259+
260+ # Add time dimension like diffusers: [B, C, H, W] -> [B, C, T, H, W]
261+ video = image_tensor .unsqueeze (2 ) # Add time dim at position 2
262+ logger .info (f"CosmosLatentPreparationStage - Added batch and time dims: { video .shape } " )
263+
264+ # Move to correct device and ensure compatible dtype for VAE
265+ # Use VAE's parameter dtype to avoid dtype mismatches
266+ if self .vae is not None :
267+ vae_dtype = next (self .vae .parameters ()).dtype
268+ else :
269+ vae_dtype = dtype
270+ video = video .to (device = device , dtype = vae_dtype )
271+ logger .info (f"CosmosLatentPreparationStage - Video tensor device: { video .device } , dtype: { video .dtype } " )
236272 elif hasattr (batch , 'preprocessed_image' ) and batch .preprocessed_image is not None :
273+ logger .info (f"CosmosLatentPreparationStage - Found preprocessed_image of type: { type (batch .preprocessed_image )} " )
237274 # Convert preprocessed image to video format
238275 if isinstance (batch .preprocessed_image , torch .Tensor ):
276+ logger .info (f"CosmosLatentPreparationStage - preprocessed_image tensor shape: { batch .preprocessed_image .shape } " )
239277 if batch .preprocessed_image .dim () == 4 : # [B, C, H, W] -> [B, C, T, H, W]
240278 video = batch .preprocessed_image .unsqueeze (2 )
279+ logger .info (f"CosmosLatentPreparationStage - Converted 4D to 5D tensor: { video .shape } " )
241280 elif batch .preprocessed_image .dim () == 5 : # Already [B, C, T, H, W]
242281 video = batch .preprocessed_image
282+ logger .info (f"CosmosLatentPreparationStage - Using 5D tensor as-is: { video .shape } " )
283+ else :
284+ logger .info ("CosmosLatentPreparationStage - No video input sources found" )
243285
244286 if video is not None :
245- video = batch .video
246287 num_cond_frames = video .size (2 )
247288
289+ logger .info (f"CosmosLatentPreparationStage - Number of conditioning frames: { num_cond_frames } " )
290+
248291 if num_cond_frames >= num_frames :
249292 # Take the last `num_frames` frames for conditioning
250293 num_cond_latent_frames = (num_frames - 1 ) // vae_scale_factor_temporal + 1
251294 video = video [:, :, - num_frames :]
295+ logger .info (f"CosmosLatentPreparationStage - Using last { num_frames } frames from { num_cond_frames } conditioning frames" )
252296 else :
253297 num_cond_latent_frames = (num_cond_frames - 1 ) // vae_scale_factor_temporal + 1
254298 num_padding_frames = num_frames - num_cond_frames
255299 last_frame = video [:, :, - 1 :]
256300 padding = last_frame .repeat (1 , 1 , num_padding_frames , 1 , 1 )
257301 video = torch .cat ([video , padding ], dim = 2 )
302+ logger .info (f"CosmosLatentPreparationStage - Padding { num_cond_frames } conditioning frames with { num_padding_frames } repeated frames" )
258303
259304 # Encode video through VAE like diffusers does
260305 if self .vae is not None :
306+ # Move VAE to correct device before encoding
307+ self .vae = self .vae .to (device )
261308 if isinstance (generator , list ):
262309 init_latents = []
263310 for i in range (batch_size ):
264311 vae_output = self .vae .encode (video [i ].unsqueeze (0 ))
312+ logger .info (f"CosmosLatentPreparationStage - VAE output type: { type (vae_output )} , attributes: { dir (vae_output )} " )
313+
314+ # Handle different VAE output types
265315 if hasattr (vae_output , 'latent_dist' ):
266316 init_latents .append (vae_output .latent_dist .sample (generator [i ] if i < len (generator ) else None ))
267317 elif hasattr (vae_output , 'latents' ):
268318 init_latents .append (vae_output .latents )
319+ elif hasattr (vae_output , 'sample' ):
320+ init_latents .append (vae_output .sample (generator [i ] if i < len (generator ) else None ))
321+ elif isinstance (vae_output , torch .Tensor ):
322+ # Direct tensor output
323+ init_latents .append (vae_output )
269324 else :
270- raise AttributeError ("Could not access latents of provided encoder_output" )
325+ # Unrecognised VAE output type: list its public attributes for debugging, then raise
326+ attrs = [attr for attr in dir (vae_output ) if not attr .startswith ('_' )]
327+ logger .info (f"CosmosLatentPreparationStage - Available attributes: { attrs } " )
328+ raise AttributeError (f"Could not access latents from VAE output. Available attributes: { attrs } " )
271329 else :
272330 init_latents_list = []
273331 for vid in video :
274332 vae_output = self .vae .encode (vid .unsqueeze (0 ))
333+ logger .info (f"CosmosLatentPreparationStage - VAE output type: { type (vae_output )} , attributes: { dir (vae_output )} " )
334+
335+ # Handle different VAE output types
275336 if hasattr (vae_output , 'latent_dist' ):
276337 init_latents_list .append (vae_output .latent_dist .sample (generator ))
277338 elif hasattr (vae_output , 'latents' ):
278339 init_latents_list .append (vae_output .latents )
340+ elif hasattr (vae_output , 'sample' ):
341+ init_latents_list .append (vae_output .sample (generator ))
342+ elif isinstance (vae_output , torch .Tensor ):
343+ # Direct tensor output
344+ init_latents_list .append (vae_output )
279345 else :
280- raise AttributeError ("Could not access latents of provided encoder_output" )
346+ # Unrecognised VAE output type: list its public attributes for debugging, then raise
347+ attrs = [attr for attr in dir (vae_output ) if not attr .startswith ('_' )]
348+ logger .info (f"CosmosLatentPreparationStage - Available attributes: { attrs } " )
349+ raise AttributeError (f"Could not access latents from VAE output. Available attributes: { attrs } " )
281350 init_latents = init_latents_list
282351
283352 init_latents = torch .cat (init_latents , dim = 0 ).to (dtype )
@@ -289,8 +358,12 @@ def forward(
289358 init_latents = (init_latents - latents_mean ) / latents_std * self .scheduler .sigma_data
290359
291360 conditioning_latents = init_latents
361+
362+ # Offload VAE to CPU after encoding to save memory
363+ self .vae .to ("cpu" )
292364 else :
293365 num_cond_latent_frames = 0
366+ logger .info ("CosmosLatentPreparationStage - No conditioning frames detected (no video input)" )
294367
295368 # Generate or use provided latents
296369 if latents is None :
@@ -301,7 +374,7 @@ def forward(
301374 else :
302375 latents = latents .to (device = device , dtype = dtype )
303376
304- # Scale latents by sigma_max (Cosmos-specific)
377+ # Scale latents by sigma_max (Cosmos-specific) - exactly like diffusers
305378 latents = latents * self .scheduler .sigma_max
306379
307380 # Create conditioning masks (for video-to-world generation)
@@ -340,8 +413,9 @@ def forward(
340413
341414 # Final verification that shape is correct
342415 logger .info (f"CosmosLatentPreparationStage - FINAL latents shape: { latents .shape } " )
343- # Compare with Diffusers but adjust for actual input dimensions
344- diffusers_expected = torch .Size ([1 , 16 , 7 , 88 , 160 ])
416+ # Compare with Diffusers expected shape for our dimensions
417+ # For 21 frames with temporal_scale=4: (21-1)//4+1 = 6 latent frames
418+ diffusers_expected = torch .Size ([1 , 16 , 6 , 88 , 160 ])
345419 if latents .shape != diffusers_expected :
346420 logger .warning (f"CosmosLatentPreparationStage - Shape differs from Diffusers: Expected { diffusers_expected } , got { latents .shape } " )
347421 logger .info (f"CosmosLatentPreparationStage - This may be due to different input dimensions (height={ height } , width={ width } )" )
0 commit comments