Commit 719af46

Fix conditioning mismatch
1 parent a8b8c78 commit 719af46

File tree

fastvideo/configs/pipelines/cosmos.py
fastvideo/pipelines/stages/denoising.py
fastvideo/pipelines/stages/latent_preparation.py
fastvideo/pipelines/stages/text_encoding.py

4 files changed: +76 -60

fastvideo/configs/pipelines/cosmos.py

Lines changed: 4 additions & 21 deletions
```diff
@@ -16,22 +16,13 @@
 def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
     """Postprocess T5 Large text encoder outputs for Cosmos pipeline.
 
-    Handles attention masks and sequence padding for the T5 Large model.
+    Return raw last_hidden_state without truncation/padding.
     """
     hidden_state = outputs.last_hidden_state
 
     if hidden_state is None:
         raise ValueError("T5 Large outputs missing last_hidden_state")
 
-    mask = outputs.attention_mask
-
-    # If no attention mask provided, assume all tokens are valid
-    if mask is None:
-        batch_size, seq_len = hidden_state.shape[:2]
-        mask = torch.ones(batch_size, seq_len, device=hidden_state.device, dtype=torch.long)
-
-    seq_lens = mask.gt(0).sum(dim=1).long()
-
     # Check for NaN values and provide debugging info
     nan_count = torch.isnan(hidden_state).sum()
     if nan_count > 0:
@@ -42,16 +33,8 @@ def t5_large_postprocess_text(outputs: BaseEncoderOutput) -> torch.Tensor:
     # Replace NaN values with zeros to avoid pipeline failure
     hidden_state = hidden_state.masked_fill(torch.isnan(hidden_state), 0.0)
 
-    # Create list of tensors with proper sequence lengths
-    prompt_embeds = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)]
-
-    # Stack tensors with padding to fixed length (like wan.py implementation)
-    prompt_embeds_tensor: torch.Tensor = torch.stack([
-        torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))])
-        for u in prompt_embeds
-    ], dim=0)
-
-    return prompt_embeds_tensor
+    # Return raw last_hidden_state (no truncation/padding)
+    return hidden_state
 
 
 @dataclass
@@ -134,7 +117,7 @@ class CosmosConfig(PipelineConfig):
 
     # Denoising parameters
     embedded_cfg_scale: int = 6
-    flow_shift: int = 7
+    flow_shift: float = 1.0  # Changed to 1.0 to match diffusers (no shift transformation)
 
     def __post_init__(self):
         self.vae_config.load_encoder = True
```
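
Context for the change above: the old postprocess trimmed each sequence to its attention-masked length and re-padded to a fixed 512 tokens; the new version hands back `last_hidden_state` untouched, and padded positions are instead zeroed in `fastvideo/pipelines/stages/text_encoding.py` (see below). A minimal sketch of the behavioral difference, assuming a `[B, L, D]` hidden state and a 0/1 attention mask; the helper names are illustrative, not part of the codebase:

```python
import torch

def old_postprocess(hidden_state: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Old behavior: trim each sequence to its valid length, then pad to 512."""
    seq_lens = mask.gt(0).sum(dim=1).long()
    trimmed = [u[:v] for u, v in zip(hidden_state, seq_lens, strict=True)]
    return torch.stack(
        [torch.cat([u, u.new_zeros(512 - u.size(0), u.size(1))]) for u in trimmed],
        dim=0)

def new_postprocess(hidden_state: torch.Tensor) -> torch.Tensor:
    """New behavior: return the encoder output unchanged."""
    return hidden_state

hs = torch.randn(1, 300, 1024)                 # [B, L, D] with L != 512
mask = torch.ones(1, 300, dtype=torch.long)
print(old_postprocess(hs, mask).shape)         # torch.Size([1, 512, 1024])
print(new_postprocess(hs).shape)               # torch.Size([1, 300, 1024])
```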

fastvideo/pipelines/stages/denoising.py

Lines changed: 51 additions & 17 deletions
```diff
@@ -750,15 +750,20 @@ def forward(
             print(f"[FASTVIDEO DEBUG] Step {i}: SKIPPING conditioning frame injection!")
             logger.warning(f"Step {i}: Missing conditioning data - cond_indicator: {hasattr(batch, 'cond_indicator')}, conditioning_latents: {conditioning_latents is not None}")
 
-        # cond_latent = cond_latent.to(target_dtype)
+        # Convert cond_latent to target dtype BEFORE debug logging to match Diffusers
+        cond_latent = cond_latent.to(target_dtype)
 
-        # # Apply conditional timestep processing like diffusers (lines 720-721)
-        # cond_timestep = timestep
-        # if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
-        #     cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
-        #     cond_timestep = cond_timestep.to(target_dtype)
-        #     if i < 3:
-        #         logger.info(f"Step {i}: Applied conditional timestep - t_conditioning: {t_conditioning:.6f}, cond_timestep sum: {cond_timestep.float().sum().item():.6f}")
+        # Apply conditional timestep processing like Diffusers (lines 792-793)
+        cond_timestep = timestep
+        if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
+            # Exactly match Diffusers: cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
+            # First get t_conditioning (sigma_conditioning value from Diffusers)
+            sigma_conditioning = 0.0001  # Same as Diffusers default
+            t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
+            cond_timestep = batch.cond_indicator * t_conditioning + (1 - batch.cond_indicator) * timestep
+            cond_timestep = cond_timestep.to(target_dtype)
+            if i < 3:
+                logger.info(f"Step {i}: Applied conditional timestep - t_conditioning: {t_conditioning:.6f}, cond_timestep sum: {cond_timestep.float().sum().item():.6f}")
 
         with set_forward_context(
             current_timestep=i,
@@ -767,7 +772,8 @@ def forward(
         ):
             # Use conditioning masks from CosmosLatentPreparationStage
             condition_mask = batch.cond_mask.to(target_dtype) if hasattr(batch, 'cond_mask') else None
-            padding_mask = torch.zeros(1, 1, cond_latent.shape[3], cond_latent.shape[4],
+            # Padding mask should match original image dimensions like Diffusers (704, 1280)
+            padding_mask = torch.zeros(1, 1, batch.height, batch.width,
                                        device=cond_latent.device, dtype=target_dtype)
 
             # Fallback if masks not available
@@ -786,10 +792,34 @@ def forward(
                 logger.info(f"  condition_mask shape: {condition_mask.shape if condition_mask is not None else None}")
                 logger.info(f"  padding_mask shape: {padding_mask.shape}")
 
+            # Log detailed transformer inputs for comparison with Diffusers
+            if i < 3:
+                print(f"FASTVIDEO TRANSFORMER INPUTS (step {i}):")
+                print(f"  hidden_states: shape={cond_latent.shape}, sum={cond_latent.float().sum().item():.6f}, mean={cond_latent.float().mean().item():.6f}")
+                print(f"  timestep: shape={cond_timestep.shape}, sum={cond_timestep.float().sum().item():.6f}, values={cond_timestep.flatten()[:5].float()}")
+                print(f"  encoder_hidden_states: shape={batch.prompt_embeds[0].shape}, sum={batch.prompt_embeds[0].float().sum().item():.6f}")
+                print(f"  condition_mask: shape={condition_mask.shape if condition_mask is not None else None}, sum={condition_mask.float().sum().item() if condition_mask is not None else None}")
+                print(f"  padding_mask: shape={padding_mask.shape}, sum={padding_mask.float().sum().item():.6f}")
+                print(f"  fps: {24}, target_dtype: {target_dtype}")
+                print(f"  DTYPES: hidden_states={cond_latent.dtype}, timestep={cond_timestep.dtype}, encoder_hidden_states={batch.prompt_embeds[0].dtype}")
+                print(f"  hidden_states first 5 values: {cond_latent.flatten()[:5].float()}")
+                print(f"  encoder_hidden_states first 5 values: {batch.prompt_embeds[0].flatten()[:5].float()}")
+                with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                    f.write(f"FASTVIDEO TRANSFORMER INPUTS (step {i}):\n")
+                    f.write(f"  hidden_states: shape={cond_latent.shape}, sum={cond_latent.float().sum().item():.6f}, mean={cond_latent.float().mean().item():.6f}\n")
+                    f.write(f"  timestep: shape={cond_timestep.shape}, sum={cond_timestep.float().sum().item():.6f}, values={cond_timestep.flatten()[:5].float()}\n")
+                    f.write(f"  encoder_hidden_states: shape={batch.prompt_embeds[0].shape}, sum={batch.prompt_embeds[0].float().sum().item():.6f}\n")
+                    f.write(f"  condition_mask: shape={condition_mask.shape if condition_mask is not None else None}, sum={condition_mask.float().sum().item() if condition_mask is not None else None}\n")
+                    f.write(f"  padding_mask: shape={padding_mask.shape}, sum={padding_mask.float().sum().item():.6f}\n")
+                    f.write(f"  fps: {24}, target_dtype: {target_dtype}\n")
+                    f.write(f"  DTYPES: hidden_states={cond_latent.dtype}, timestep={cond_timestep.dtype}, encoder_hidden_states={batch.prompt_embeds[0].dtype}\n")
+                    f.write(f"  hidden_states first 5 values: {cond_latent.flatten()[:5].float()}\n")
+                    f.write(f"  encoder_hidden_states first 5 values: {batch.prompt_embeds[0].flatten()[:5].float()}\n")
+
             print(f"[FASTVIDEO DENOISING] About to call transformer with hidden_states sum = {cond_latent.float().sum().item()}")
             noise_pred = self.transformer(
-                hidden_states=cond_latent.to(target_dtype),
-                timestep=timestep.to(target_dtype),
+                hidden_states=cond_latent,  # Already converted to target_dtype above
+                timestep=cond_timestep.to(target_dtype),
                 encoder_hidden_states=batch.prompt_embeds[0].to(target_dtype),
                 fps=24,  # TODO: get fps from batch or config
                 condition_mask=condition_mask,
@@ -805,11 +835,7 @@ def forward(
             # Apply preconditioning exactly like diffusers
            cond_pred = (c_skip * latents + c_out * noise_pred.float()).to(target_dtype)
 
-            # Apply conditional indicator masking (from CosmosLatentPreparationStage)
-            if hasattr(batch, 'cond_indicator') and batch.cond_indicator is not None:
-                conditioning_latents = batch.conditioning_latents if batch.conditioning_latents is not None else torch.zeros_like(latents)
-                cond_pred = batch.cond_indicator * conditioning_latents + (1 - batch.cond_indicator) * cond_pred
-
+            # NOTE: Conditioning frame injection is applied to cond_latent BEFORE transformer call (line 746), not after
             # Classifier-free guidance
             if batch.do_classifier_free_guidance and batch.negative_prompt_embeds is not None:
                 # Unconditional pass - match diffusers logic (lines 755-759)
@@ -830,9 +856,17 @@ def forward(
                     logger.info(f"  negative_prompt_embeds shape: {batch.negative_prompt_embeds[0].shape}")
                     # sum: {uncond_timestep.float().sum().item():.6f}")
 
+                # Apply same conditional timestep processing for unconditional pass
+                uncond_timestep = timestep
+                if hasattr(batch, 'uncond_indicator') and batch.uncond_indicator is not None:
+                    sigma_conditioning = 0.0001  # Same as Diffusers default
+                    t_conditioning = sigma_conditioning / (sigma_conditioning + 1)
+                    uncond_timestep = batch.uncond_indicator * t_conditioning + (1 - batch.uncond_indicator) * timestep
+                    uncond_timestep = uncond_timestep.to(target_dtype)
+
                 noise_pred_uncond = self.transformer(
                     hidden_states=uncond_latent.to(target_dtype),
-                    timestep=timestep.to(target_dtype),
+                    timestep=uncond_timestep.to(target_dtype),
                     encoder_hidden_states=batch.negative_prompt_embeds[0].to(target_dtype),
                     fps=24,  # TODO: get fps from batch or config
                     condition_mask=uncond_condition_mask,
```
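
The heart of this file's change is the per-frame timestep blend: frames marked by `cond_indicator` (the injected conditioning frames) are given the near-zero timestep `t_conditioning = sigma_conditioning / (1 + sigma_conditioning) ≈ 0.0001`, i.e. treated as nearly clean, while all other frames keep the scheduler's timestep. A minimal sketch of that blend, assuming `cond_indicator` is a broadcastable 0/1 mask over frames; the shapes and the 0.75 timestep are illustrative:

```python
import torch

sigma_conditioning = 0.0001
t_conditioning = sigma_conditioning / (sigma_conditioning + 1)  # ~0.0001

timestep = torch.full((1, 1, 8, 1, 1), 0.75)  # per-frame scheduler timestep
cond_indicator = torch.zeros(1, 1, 8, 1, 1)   # 1 marks conditioning frames
cond_indicator[:, :, 0] = 1.0                 # frame 0 carries the input image

# Conditioning frames get ~0 (almost no noise); generated frames keep 0.75
cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
print(cond_timestep.flatten().tolist())
```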

fastvideo/pipelines/stages/latent_preparation.py

Lines changed: 16 additions & 22 deletions
```diff
@@ -240,36 +240,30 @@ def forward(
                 logger.info(f"CosmosLatentPreparationStage - Using 5D tensor as-is: {video.shape}")
             else:
                 logger.info("CosmosLatentPreparationStage - pil_image is not a tensor, needs preprocessing")
-                # Following diffusers approach for image-to-video preprocessing
-                # Convert PIL image to tensor and add temporal dimension
-                import torchvision.transforms as transforms
+                # Use same preprocessing as diffusers VideoProcessor
+                from diffusers.video_processor import VideoProcessor
 
-                # Create transform pipeline similar to diffusers VideoProcessor
-                transform = transforms.Compose([
-                    transforms.Resize((height, width), antialias=True),
-                    transforms.ToTensor(),
-                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
-                ])
+                # Create VideoProcessor with same parameters as diffusers Cosmos pipeline
+                vae_scale_factor_spatial = 8  # Same as diffusers
+                video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)
 
-                # Apply transform to get [C, H, W] tensor
-                image_tensor = transform(batch.pil_image)
-                logger.info(f"CosmosLatentPreparationStage - Transformed PIL to tensor: {image_tensor.shape}")
+                # Use exact same method as diffusers: preprocess image then unsqueeze for time dimension
+                processed_image = video_processor.preprocess(batch.pil_image, height, width)
+                logger.info(f"CosmosLatentPreparationStage - VideoProcessor preprocess result: shape={processed_image.shape}, dtype={processed_image.dtype}, device={processed_image.device}")
 
-                # Add batch dimension: [C, H, W] -> [B, C, H, W]
-                image_tensor = image_tensor.unsqueeze(0)
+                # Add time dimension exactly like diffusers: unsqueeze(2)
+                video = processed_image.unsqueeze(2)
+                logger.info(f"CosmosLatentPreparationStage - After unsqueeze(2): shape={video.shape}, dtype={video.dtype}, device={video.device}")
 
-                # Add time dimension like diffusers: [B, C, H, W] -> [B, C, T, H, W]
-                video = image_tensor.unsqueeze(2)  # Add time dim at position 2
-                logger.info(f"CosmosLatentPreparationStage - Added batch and time dims: {video.shape}")
-
-                # Move to correct device and ensure compatible dtype for VAE
-                # Use VAE's parameter dtype to avoid dtype mismatches
+                # Exactly match diffusers' device/dtype handling: to(device=device, dtype=vae_dtype)
+                # Get VAE dtype exactly like diffusers
                 if self.vae is not None:
-                    vae_dtype = next(self.vae.parameters()).dtype
+                    vae_dtype = next(self.vae.parameters()).dtype  # Get VAE's parameter dtype
                 else:
                     vae_dtype = dtype
+
                 video = video.to(device=device, dtype=vae_dtype)
-                logger.info(f"CosmosLatentPreparationStage - Video tensor device: {video.device}, dtype: {video.dtype}")
+                logger.info(f"CosmosLatentPreparationStage - After to(device, dtype): shape={video.shape}, dtype={video.dtype}, device={video.device}, vae_dtype={vae_dtype}")
             elif hasattr(batch, 'preprocessed_image') and batch.preprocessed_image is not None:
                 logger.info(f"CosmosLatentPreparationStage - Found preprocessed_image of type: {type(batch.preprocessed_image)}")
                 # Convert preprocessed image to video format
```
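
This swaps the hand-rolled torchvision transform for `diffusers`' own `VideoProcessor`, so resizing and [-1, 1] normalization follow the reference Cosmos pipeline exactly, with the time dimension added via `unsqueeze(2)`. A standalone sketch of that path, assuming a single RGB PIL image; the 704x1280 resolution is illustrative:

```python
import torch
from PIL import Image
from diffusers.video_processor import VideoProcessor

height, width = 704, 1280
video_processor = VideoProcessor(vae_scale_factor=8)  # spatial scale, as in diffusers

image = Image.new("RGB", (width, height))             # stand-in for batch.pil_image
processed = video_processor.preprocess(image, height, width)  # [B, C, H, W] in [-1, 1]
video = processed.unsqueeze(2)                        # [B, C, T=1, H, W]
print(video.shape)  # torch.Size([1, 3, 1, 704, 1280])
```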

fastvideo/pipelines/stages/text_encoding.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -90,6 +90,11 @@ def forward(
         # Write to output file
         with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
             f.write(f"TextEncodingStage: prompt_embeds sum = {sum_value:.6f}\n")
+
+        lengths = attention_mask.sum(dim=1).cpu()
+        for i, length in enumerate(lengths):
+            prompt_embeds[i, length:] = 0
+
         batch.prompt_embeds.append(prompt_embeds)
         if batch.prompt_attention_mask is not None:
             batch.prompt_attention_mask.append(attention_mask)
```
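
With `t5_large_postprocess_text` now returning the raw `last_hidden_state`, the encoder still emits nonzero vectors at padded token positions, so this loop zeroes everything past each prompt's valid length before the embeddings are stored. A minimal sketch with illustrative shapes:

```python
import torch

prompt_embeds = torch.randn(2, 6, 4)        # [B, L, D]
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                               [1, 1, 1, 1, 1, 0]])

lengths = attention_mask.sum(dim=1)         # valid tokens per prompt: [3, 5]
for i, length in enumerate(lengths):
    prompt_embeds[i, length:] = 0           # zero out padded positions

assert prompt_embeds[0, 3:].abs().sum() == 0
assert prompt_embeds[1, 5:].abs().sum() == 0
```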
