Commit f747b40

Addt. logging
1 parent 11aaaae commit f747b40

5 files changed: 217 additions, 65 deletions

fastvideo/models/dits/cosmos.py

Lines changed: 118 additions & 12 deletions
@@ -76,8 +76,14 @@ def __init__(self, embedding_dim: int, condition_dim: int) -> None:
     def forward(self, hidden_states: torch.Tensor,
                 timestep: torch.LongTensor) -> torch.Tensor:
         timesteps_proj = self.time_proj(timestep).type_as(hidden_states)
+        print(f"[FASTVIDEO] timesteps_proj before norm: {timesteps_proj.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] timesteps_proj before norm: {timesteps_proj.float().sum().item()}\n")
         temb = self.t_embedder(timesteps_proj)
         embedded_timestep = self.norm(timesteps_proj)
+        print(f"[FASTVIDEO] embedded_timestep after norm: {embedded_timestep.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] embedded_timestep after norm: {embedded_timestep.float().sum().item()}\n")
         return temb, embedded_timestep


@@ -133,10 +139,7 @@ def __init__(self,
         else:
             self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)

-        self.linear_2 = nn.Linear(
-            hidden_features if hidden_features is not None else in_features,
-            3 * in_features,
-            bias=False)
+        self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False)

     def forward(
         self,
@@ -197,10 +200,16 @@ def forward(self,
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states

+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO SELF-ATTN] INIT hidden_states: Q={hidden_states.float().sum().item()}\n")
+
         # Get QKV
         query = self.to_q(hidden_states)
         key = self.to_k(encoder_hidden_states)
         value = self.to_v(encoder_hidden_states)
+        print(f"[FASTVIDEO SELF-ATTN] QKV sums: Q={query.float().sum().item()}, K={key.float().sum().item()}, V={value.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO SELF-ATTN] QKV sums: Q={query.float().sum().item()}, K={key.float().sum().item()}, V={value.float().sum().item()}\n")

         # Reshape for multi-head attention
         query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
@@ -209,9 +218,9 @@ def forward(self,

         # Apply normalization
         if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
         if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)

         # Apply RoPE if provided
         if image_rotary_emb is not None:
@@ -224,12 +233,28 @@ def forward(self,
                 use_real=True,
                 use_real_unbind_dim=-2)

+        # Prepare for GQA (Grouped Query Attention)
+        if torch.onnx.is_in_onnx_export():
+            query_idx = torch.tensor(query.size(3), device=query.device)
+            key_idx = torch.tensor(key.size(3), device=key.device)
+            value_idx = torch.tensor(value.size(3), device=value.device)
+        else:
+            query_idx = query.size(3)
+            key_idx = key.size(3)
+            value_idx = value.size(3)
+        key = key.repeat_interleave(query_idx // key_idx, dim=3)
+        value = value.repeat_interleave(query_idx // value_idx, dim=3)
+
         # Attention computation
         # Use standard PyTorch scaled dot product attention
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
         attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)
+        print(f"[FASTVIDEO TRANSFORMER] hidden_states: {attn_output.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO TRANSFORMER] hidden_states: {attn_output.float().sum().item()}\n")
+            f.write(f"self.to_out: {self.to_out}")

         # Output projection
         attn_output = self.to_out(attn_output)
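
Editor's note on the GQA block added above: the key/value projections may carry fewer heads than the query projection, so the KV tensors are expanded with repeat_interleave until scaled_dot_product_attention sees matching head counts; the ONNX branch presumably wraps the sizes in tensors only to keep them traceable during export. A minimal standalone sketch of the expansion idea, assuming a [batch, heads, seq, head_dim] layout where heads live on dim 1 (the commit applies the same pattern to its own tensor layout); names and shapes here are illustrative, not FastVideo code:

# Editor's sketch (not part of the commit): expand grouped KV heads so that
# scaled_dot_product_attention sees matching head counts.
import torch
import torch.nn.functional as F

batch, q_heads, kv_heads, seq, head_dim = 2, 8, 2, 16, 64
query = torch.randn(batch, q_heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# Each KV head serves q_heads // kv_heads query heads, so duplicate the KV
# heads along the head axis (dim 1 in this layout) before attention.
key = key.repeat_interleave(q_heads // kv_heads, dim=1)
value = value.repeat_interleave(q_heads // kv_heads, dim=1)

out = F.scaled_dot_product_attention(query, key, value, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 16, 64])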
@@ -275,6 +300,9 @@ def forward(self,
         query = self.to_q(hidden_states)
         key = self.to_k(encoder_hidden_states)
         value = self.to_v(encoder_hidden_states)
+        # print(f"[FASTVIDEO CROSS-ATTN] QKV sums: Q={query.float().sum().item()}, K={key.float().sum().item()}, V={value.float().sum().item()}")
+        # with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        #     f.write(f"[FASTVIDEO CROSS-ATTN] QKV sums: Q={query.float().sum().item()}, K={key.float().sum().item()}, V={value.float().sum().item()}\n")

         # Reshape for multi-head attention
         # Standard PyTorch attention expects [batch, num_heads, seq_len, head_dim]
@@ -284,13 +312,25 @@ def forward(self,

         # Apply normalization
         if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
         if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)
+
+        # Prepare for GQA (Grouped Query Attention)
+        if torch.onnx.is_in_onnx_export():
+            query_idx = torch.tensor(query.size(3), device=query.device)
+            key_idx = torch.tensor(key.size(3), device=key.device)
+            value_idx = torch.tensor(value.size(3), device=value.device)
+        else:
+            query_idx = query.size(3)
+            key_idx = key.size(3)
+            value_idx = value.size(3)
+        key = key.repeat_interleave(query_idx // key_idx, dim=3)
+        value = value.repeat_interleave(query_idx // value_idx, dim=3)

         # Attention computation
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
         attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)

@@ -317,6 +357,11 @@ def __init__(

         hidden_size = num_attention_heads * attention_head_dim

+        print(f"[FASTVIDEO TRANSFORMER] hidden_size: Q={hidden_size}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO TRANSFORMER] hidden_size: Q={hidden_size}\n")
+
+
         self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
         self.attn1 = CosmosSelfAttention(
@@ -355,18 +400,51 @@ def forward(
         hidden_states = hidden_states + extra_pos_emb

         # 1. Self Attention
+        print(f"[FASTVIDEO DEBUG] Before norm1: hidden_states={hidden_states.float().sum().item()}")
+        print(f"[FASTVIDEO DEBUG] Before norm1: embedded_timestep={embedded_timestep.float().sum().item()}")
+        print(f"[FASTVIDEO DEBUG] Before norm1: temb={temb.float().sum().item() if temb is not None else 'None'}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO DEBUG] Before norm1: hidden_states={hidden_states.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO DEBUG] Before norm1: embedded_timestep={embedded_timestep.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO DEBUG] Before norm1: temb={temb.float().sum().item() if temb is not None else 'None'}\n")
+        # Debug norm1 weights
+        print(f"[FASTVIDEO DEBUG] norm1.linear_1.weight sum: {self.norm1.linear_1.weight.float().sum().item()}")
+        print(f"[FASTVIDEO DEBUG] norm1.linear_2.weight sum: {self.norm1.linear_2.weight.float().sum().item()}")
+        print(f"[FASTVIDEO DEBUG] hidden_states dtype: {hidden_states.dtype}")
+        print(f"[FASTVIDEO DEBUG] embedded_timestep dtype: {embedded_timestep.dtype}")
+        print(f"[FASTVIDEO DEBUG] temb dtype: {temb.dtype if temb is not None else 'None'}")
+        print(f"[FASTVIDEO DEBUG] norm1.linear_1.weight dtype: {self.norm1.linear_1.weight.dtype}")
+        print(f"[FASTVIDEO DEBUG] norm1.linear_2.weight dtype: {self.norm1.linear_2.weight.dtype}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO DEBUG] norm1.linear_1.weight sum: {self.norm1.linear_1.weight.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO DEBUG] norm1.linear_2.weight sum: {self.norm1.linear_2.weight.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO DEBUG] hidden_states dtype: {hidden_states.dtype}\n")
+            f.write(f"[FASTVIDEO DEBUG] embedded_timestep dtype: {embedded_timestep.dtype}\n")
+            f.write(f"[FASTVIDEO DEBUG] temb dtype: {temb.dtype if temb is not None else 'None'}\n")
+            f.write(f"[FASTVIDEO DEBUG] norm1.linear_1.weight dtype: {self.norm1.linear_1.weight.dtype}\n")
+            f.write(f"[FASTVIDEO DEBUG] norm1.linear_2.weight dtype: {self.norm1.linear_2.weight.dtype}\n")
+
         norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep,
                                               temb)
+        print(f"[FASTVIDEO DEBUG] After norm1: norm_hidden_states={norm_hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO DEBUG] After norm1: norm_hidden_states={norm_hidden_states.float().sum().item()}\n")
         attn_output = self.attn1(norm_hidden_states,
                                  image_rotary_emb=image_rotary_emb)
         hidden_states = hidden_states + gate * attn_output

         # 2. Cross Attention
+        # print(f"[FASTVIDEO] About to call cross-attention")
+        # with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        #     f.write(f"[FASTVIDEO] About to call cross-attention\n")
         norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep,
                                               temb)
         attn_output = self.attn2(norm_hidden_states,
                                  encoder_hidden_states=encoder_hidden_states,
                                  attention_mask=attention_mask)
+        # print(f"[FASTVIDEO] Cross-attention completed")
+        # with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        #     f.write(f"[FASTVIDEO] Cross-attention completed\n")
         hidden_states = hidden_states + gate * attn_output

         # 3. Feed Forward
@@ -604,6 +682,8 @@ def forward(self,
                 padding_mask: torch.Tensor | None = None,
                 **kwargs) -> torch.Tensor:
         print(f"[FASTVIDEO TRANSFORMER] Input hidden_states sum = {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO TRANSFORMER] Input hidden_states sum = {hidden_states.float().sum().item()}\n")
         forward_batch = get_forward_context().forward_batch
         enable_teacache = forward_batch is not None and forward_batch.enable_teacache

@@ -676,9 +756,19 @@ def forward(self,
         else:
             raise ValueError(f"Unsupported timestep shape: {timestep.shape}")

+        print(f"[FASTVIDEO] After patch_embed: {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After patch_embed: {hidden_states.float().sum().item()}\n")
+        print(f"[FASTVIDEO] After time_embed temb: {temb.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After time_embed temb: {temb.float().sum().item()}\n")
+        print(f"[FASTVIDEO] After time_embed embedded_timestep: {embedded_timestep.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After time_embed embedded_timestep: {embedded_timestep.float().sum().item()}\n")
+
         # 6. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for block in self.transformer_blocks:
+            for i, block in enumerate(self.transformer_blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
@@ -689,8 +779,12 @@ def forward(self,
                     extra_pos_emb,
                     attention_mask,
                 )
+                if i < 3:  # Log first 3 blocks
+                    print(f"[FASTVIDEO] After block {i}: {hidden_states.float().sum().item()}")
+                    with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                        f.write(f"[FASTVIDEO] After block {i}: {hidden_states.float().sum().item()}\n")
         else:
-            for block in self.transformer_blocks:
+            for i, block in enumerate(self.transformer_blocks):
                 hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -700,10 +794,20 @@ def forward(self,
                     extra_pos_emb=extra_pos_emb,
                     attention_mask=attention_mask,
                 )
+                if i < 3:  # Log first 3 blocks
+                    print(f"[FASTVIDEO] After block! {i}: {hidden_states.float().sum().item()}")
+                    with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                        f.write(f"[FASTVIDEO] After block! {i}: {hidden_states.float().sum().item()}\n")

         # 7. Output norm & projection & unpatchify
         hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
+        print(f"[FASTVIDEO] After norm_out: {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After norm_out: {hidden_states.float().sum().item()}\n")
         hidden_states = self.proj_out(hidden_states)
+        print(f"[FASTVIDEO] After proj_out: {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After proj_out: {hidden_states.float().sum().item()}\n")
         hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
         hidden_states = hidden_states.unflatten(
             1, (post_patch_num_frames, post_patch_height, post_patch_width))
@@ -713,4 +817,6 @@ def forward(self,
         hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

         print(f"[FASTVIDEO TRANSFORMER] Output hidden_states sum = {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO TRANSFORMER] Output hidden_states sum = {hidden_states.float().sum().item()}\n")
         return hidden_states
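
Editor's note: every debug tap in this file repeats the same print plus open-append-write pair against /workspace/FastVideo/fastvideo_hidden_states.log. A small helper could collapse each pair to a single call; a minimal sketch (the helper name log_sum is hypothetical, not part of the commit):

# Editor's sketch (not part of the commit): one call replacing the repeated
# print + open(..., "a") + f.write pattern used by the debug taps above.
import torch

LOG_PATH = "/workspace/FastVideo/fastvideo_hidden_states.log"

def log_sum(tag: str, tensor: torch.Tensor | None) -> None:
    # Summing in float32 keeps the checksum comparable across bf16/fp32 runs,
    # mirroring the .float().sum().item() idiom used throughout the commit.
    value = "None" if tensor is None else tensor.float().sum().item()
    line = f"[FASTVIDEO] {tag}: {value}"
    print(line)
    with open(LOG_PATH, "a") as f:
        f.write(line + "\n")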

fastvideo/pipelines/basic/cosmos/cosmos_pipeline.py

Lines changed: 22 additions & 0 deletions
@@ -10,6 +10,11 @@
 import numpy as np
 import torch

+# TEMPORARY: Import diffusers VAE for comparison
+import sys
+sys.path.insert(0, '/workspace/diffusers/src')
+from diffusers.models.autoencoders.autoencoder_kl_wan import AutoencoderKLWan as DiffusersAutoencoderKLWan
+
 from fastvideo.fastvideo_args import FastVideoArgs
 from fastvideo.logger import init_logger
 from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase
@@ -33,6 +38,23 @@ class Cosmos2VideoToWorldPipeline(ComposedPipelineBase):

     def initialize_pipeline(self, fastvideo_args: FastVideoArgs):

+        # TEMPORARY: Replace FastVideo VAE with diffusers VAE for testing
+        print("[TEMPORARY] Replacing FastVideo VAE with diffusers VAE...")
+        original_vae = self.modules["vae"]
+        print(f"[TEMPORARY] Original VAE type: {type(original_vae)}")
+
+        # Load diffusers VAE with same config
+        diffusers_vae = DiffusersAutoencoderKLWan.from_pretrained(
+            self.model_path,
+            subfolder="vae",
+            torch_dtype=torch.bfloat16,
+        )
+        print(f"[TEMPORARY] Diffusers VAE type: {type(diffusers_vae)}")
+
+        # Replace the VAE module
+        self.modules["vae"] = diffusers_vae
+        print("[TEMPORARY] VAE replacement complete!")
+
         self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler(
             shift=fastvideo_args.pipeline_config.flow_shift)

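Editor's note: swapping in the diffusers VAE enables an A/B check that isolates whether the FastVideo and diffusers pipelines diverge at the VAE stage or earlier. A minimal sketch of such a check (hypothetical helper, not in the commit; it assumes the FastVideo VAE returns a tensor while the diffusers VAE returns a DecoderOutput with .sample, the same contract the decoding-stage changes below handle):

# Editor's sketch (not part of the commit): decode identical latents with both
# VAEs and report the worst-case elementwise gap.
import torch

@torch.no_grad()
def compare_vae_decode(fastvideo_vae, diffusers_vae, latents: torch.Tensor) -> float:
    a = fastvideo_vae.decode(latents)          # FastVideo VAE: returns a tensor
    b = diffusers_vae.decode(latents).sample   # diffusers VAE: DecoderOutput.sample
    return (a.float() - b.float()).abs().max().item()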

fastvideo/pipelines/stages/decoding.py

Lines changed: 19 additions & 6 deletions
@@ -92,11 +92,16 @@ def forward(
         vae_autocast_enabled = (vae_dtype != torch.float32
                                 ) and not fastvideo_args.disable_autocast

-        if isinstance(self.vae.scaling_factor, torch.Tensor):
-            latents = latents / self.vae.scaling_factor.to(
-                latents.device, latents.dtype)
-        else:
-            latents = latents / self.vae.scaling_factor
+        # TEMPORARY: Handle diffusers VAE compatibility
+        if hasattr(self.vae, 'scaling_factor'):
+            if isinstance(self.vae.scaling_factor, torch.Tensor):
+                latents = latents / self.vae.scaling_factor.to(
+                    latents.device, latents.dtype)
+            else:
+                latents = latents / self.vae.scaling_factor
+        elif hasattr(self.vae, 'config') and hasattr(self.vae.config, 'scaling_factor'):
+            # Fallback to config scaling factor for diffusers VAE
+            latents = latents / self.vae.config.scaling_factor

         # Apply shifting if needed
         if (hasattr(self.vae, "shift_factor")
@@ -117,7 +122,15 @@
         # self.vae.enable_parallel()
         if not vae_autocast_enabled:
             latents = latents.to(vae_dtype)
-        image = self.vae.decode(latents)
+        decode_output = self.vae.decode(latents)
+
+        # TEMPORARY: Handle diffusers VAE decode output compatibility
+        if hasattr(decode_output, 'sample'):
+            # Diffusers VAE returns DecoderOutput with .sample attribute
+            image = decode_output.sample
+        else:
+            # FastVideo VAE returns tensor directly
+            image = decode_output

         # Normalize image to [0, 1] range
         image = (image / 2 + 0.5).clamp(0, 1)
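
Editor's note: the two compatibility branches above share one duck-typing idea and could be folded into helpers that work with either VAE. A minimal sketch (hypothetical helper names, not part of the commit):

# Editor's sketch (not part of the commit): resolve the scaling factor and
# unwrap the decode output for either the FastVideo or the diffusers VAE.
import torch

def resolve_scaling_factor(vae):
    # FastVideo VAE exposes scaling_factor directly; the diffusers VAE keeps
    # it on its config object.
    if hasattr(vae, "scaling_factor"):
        return vae.scaling_factor
    if hasattr(vae, "config") and hasattr(vae.config, "scaling_factor"):
        return vae.config.scaling_factor
    return None

def unwrap_decode(decode_output) -> torch.Tensor:
    # diffusers returns DecoderOutput with .sample; FastVideo returns the tensor.
    return decode_output.sample if hasattr(decode_output, "sample") else decode_output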

fastvideo/pipelines/stages/denoising.py

Lines changed: 2 additions & 1 deletion
@@ -815,8 +815,9 @@
             f.write(f" DTYPES: hidden_states={cond_latent.dtype}, timestep={cond_timestep.dtype}, encoder_hidden_states={batch.prompt_embeds[0].dtype}\n")
             f.write(f" hidden_states first 5 values: {cond_latent.flatten()[:5].float()}\n")
             f.write(f" encoder_hidden_states first 5 values: {batch.prompt_embeds[0].flatten()[:5].float()}\n")
-
+            f.write(f" [FASTVIDEO DENOISING] About to call transformer with hidden_states sum = {cond_latent.float().sum().item()}\n")
         print(f"[FASTVIDEO DENOISING] About to call transformer with hidden_states sum = {cond_latent.float().sum().item()}")
+
         noise_pred = self.transformer(
             hidden_states=cond_latent,  # Already converted to target_dtype above
             timestep=cond_timestep.to(target_dtype),
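
Editor's note: the point of mirroring every print into fastvideo_hidden_states.log is that the trace can be diffed against a reference run (for example, the same checksums logged from a diffusers run). A minimal sketch of lining up two such traces (hypothetical paths and helper name, not part of the commit):

# Editor's sketch (not part of the commit): report the first line where two
# checksum traces diverge.
def first_divergence(path_a: str, path_b: str) -> tuple[int, str, str] | None:
    with open(path_a) as fa, open(path_b) as fb:
        for i, (line_a, line_b) in enumerate(zip(fa, fb), start=1):
            if line_a.strip() != line_b.strip():
                return i, line_a.strip(), line_b.strip()
    return None  # traces agree on all compared lines

# Hypothetical usage:
# first_divergence("/workspace/FastVideo/fastvideo_hidden_states.log",
#                  "/workspace/diffusers_hidden_states.log")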
