Commit 83c5dc3

Fix norm
1 parent f747b40 commit 83c5dc3

4 files changed (+173, -24 lines)

4 files changed

+173
-24
lines changed

fastvideo/layers/layernorm.py

Lines changed: 16 additions & 0 deletions

@@ -39,6 +39,22 @@ def __init__(
         if self.has_weight:
             self.weight = nn.Parameter(self.weight)
 
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """Forward method that matches the Diffusers RMSNorm implementation exactly."""
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+        if self.has_weight and self.weight is not None:
+            # convert into half-precision if necessary (match Diffusers exactly)
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = hidden_states * self.weight
+        else:
+            hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
     # if we do fully_shard(model.layer_norm), and we call layer_norm.forward_native(input) instead of layer_norm(input),
     # we need to call model.layer_norm.register_fsdp_forward_method(model, "forward_native") to make sure fsdp2 hooks are triggered
     # for mixed precision and cpu offloading
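
Note (not part of the diff): the new forward accumulates the variance in float32 and only casts back to half precision when the weight itself is half precision. A minimal sketch of that behaviour, with an illustrative helper name and toy shapes that are not from the repository:

import torch

def rms_norm_like_diffusers(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Accumulate the variance in float32, exactly as in the forward above.
    variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)          # promoted to float32 by the rsqrt term
    if weight.dtype in (torch.float16, torch.bfloat16):
        x = x.to(weight.dtype)                   # cast back before applying the scale
    return x * weight

x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
w = torch.ones(64, dtype=torch.bfloat16)
out = rms_norm_like_diffusers(x, w)
print(out.dtype)                            # torch.bfloat16, because the weight is half precision
print(out.float().pow(2).mean(-1).sqrt())   # per-row RMS is roughly 1 after normalization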

fastvideo/layers/rotary_embedding.py

Lines changed: 3 additions & 2 deletions

@@ -64,8 +64,9 @@ def apply_rotary_emb(
     """
     if use_real:
         cos, sin = freqs_cis  # [S, D]
-        cos = cos[None, None]
-        sin = sin[None, None]
+        # Match Diffusers exact broadcasting (sequence_dim=2 case)
+        cos = cos[None, None, :, :]
+        sin = sin[None, None, :, :]
         cos, sin = cos.to(x.device), sin.to(x.device)
 
         if use_real_unbind_dim == -1:
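
Note (not part of the diff): the change reshapes the [S, D] cos/sin tables to [1, 1, S, D] so they broadcast over batch and heads when the sequence dimension sits at dim 2. A standalone sketch of the real-valued rotary application, assuming a [batch, heads, seq, head_dim] layout and the use_real_unbind_dim == -1 convention (cosmos.py below uses -2); the random cos/sin stand in for precomputed rotary tables:

import torch

B, H, S, D = 2, 4, 16, 8
x = torch.randn(B, H, S, D)
cos = torch.randn(S, D)   # stand-in for precomputed rotary cosines
sin = torch.randn(S, D)   # stand-in for precomputed rotary sines

# [S, D] -> [1, 1, S, D] so the frequencies broadcast over batch and heads.
cos = cos[None, None, :, :]
sin = sin[None, None, :, :]

# Rotate-half formulation for the use_real_unbind_dim == -1 case.
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)    # [B, H, S, D/2] each
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)  # back to [B, H, S, D]
out = x * cos + x_rotated * sin
print(out.shape)  # torch.Size([2, 4, 16, 8])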

fastvideo/models/dits/cosmos.py

Lines changed: 86 additions & 22 deletions
@@ -113,7 +113,9 @@ def forward(self,
                                           self.embedding_dim]
 
         shift, scale = embedded_timestep.chunk(2, dim=-1)
-        hidden_states = self.norm(hidden_states)
+        # Disable autocast for LayerNorm to match Diffusers behavior
+        with torch.autocast(device_type="cuda", enabled=False):
+            hidden_states = self.norm(hidden_states)
 
         if embedded_timestep.ndim == 2:
             shift, scale = (x.unsqueeze(1) for x in (shift, scale))
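
Note (not part of the diff): the autocast guard makes the LayerNorm see the tensor's actual dtype instead of whatever the surrounding autocast context would pick. A small probe for inspecting the difference, assuming a CUDA device with bfloat16 support; it only prints dtypes and the numerical gap, it is not a claim about the exact policy:

import torch
import torch.nn as nn

norm = nn.LayerNorm(64, elementwise_affine=False).cuda()
x = torch.randn(2, 16, 64, device="cuda", dtype=torch.bfloat16)

with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y_autocast = norm(x)                      # dtype chosen by the autocast policy
    with torch.autocast(device_type="cuda", enabled=False):
        y_plain = norm(x)                     # runs in the input dtype (bfloat16 here)

print(y_autocast.dtype, y_plain.dtype)
print((y_autocast.float() - y_plain.float()).abs().max())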
@@ -147,6 +149,9 @@ def forward(
                 embedded_timestep: torch.Tensor,
                 temb: torch.Tensor | None = None,
                 ) -> torch.Tensor:
+        instance_id = id(self)
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO NORM] Instance {instance_id}: forward hidden_states: {hidden_states.float().sum().item()}\n")
         embedded_timestep = self.activation(embedded_timestep)
         embedded_timestep = self.linear_1(embedded_timestep)
         embedded_timestep = self.linear_2(embedded_timestep)
@@ -155,8 +160,45 @@ def forward(
             embedded_timestep = embedded_timestep + temb
 
         shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
-        hidden_states = self.norm(hidden_states)
-
+        print(f"[FASTVIDEO NORM] After chunk - shift sum: {shift.float().sum().item()}")
+        print(f"[FASTVIDEO NORM] After chunk - scale sum: {scale.float().sum().item()}")
+        print(f"[FASTVIDEO NORM] After chunk - gate sum: {gate.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO NORM] After chunk - shift sum: {shift.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO NORM] After chunk - scale sum: {scale.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO NORM] After chunk - gate sum: {gate.float().sum().item()}\n")
+        print(f"[FASTVIDEO NORM] Before LayerNorm - input shape: {hidden_states.shape}")
+        print(f"[FASTVIDEO NORM] Before LayerNorm - input dtype: {hidden_states.dtype}")
+        print(f"[FASTVIDEO NORM] Before LayerNorm - input sum: {hidden_states.float().sum().item()}")
+        print(f"[FASTVIDEO NORM] LayerNorm eps: {self.norm.eps}")
+        print(f"[FASTVIDEO NORM] LayerNorm elementwise_affine: {self.norm.elementwise_affine}")
+        print(f"[FASTVIDEO NORM] LayerNorm normalized_shape: {self.norm.normalized_shape}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO NORM] Before LayerNorm - input shape: {hidden_states.shape}\n")
+            f.write(f"[FASTVIDEO NORM] Before LayerNorm - input dtype: {hidden_states.dtype}\n")
+            f.write(f"[FASTVIDEO NORM] Before LayerNorm - input sum: {hidden_states.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO NORM] LayerNorm eps: {self.norm.eps}\n")
+            f.write(f"[FASTVIDEO NORM] LayerNorm elementwise_affine: {self.norm.elementwise_affine}\n")
+            f.write(f"[FASTVIDEO NORM] LayerNorm normalized_shape: {self.norm.normalized_shape}\n")
+
+        # Save the input tensor for comparison (only once globally)
+        import os
+        if not hasattr(CosmosAdaLayerNormZero, '_global_tensor_saved'):
+            instance_id = id(self)
+            torch.save(hidden_states.float(), "/workspace/FastVideo/fastvideo_layernorm_input.pt")
+            print(f"[FASTVIDEO NORM] Instance {instance_id}: Saved input tensor sum={hidden_states.float().sum().item()}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO NORM] Instance {instance_id}: Saved input tensor sum={hidden_states.float().sum().item()}\n")
+            CosmosAdaLayerNormZero._global_tensor_saved = True
+
+        # Disable autocast for LayerNorm to match Diffusers behavior
+        with torch.autocast(device_type="cuda", enabled=False):
+            hidden_states = self.norm(hidden_states)
+
+        print(f"[FASTVIDEO NORM] After LayerNorm - output sum: {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO NORM] After norm: {hidden_states.float().sum().item()}\n")
+            f.write(f"embedded_timestep.ndim: {embedded_timestep.ndim}\n")
         if embedded_timestep.ndim == 2:
             shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))
 
@@ -185,6 +227,7 @@ def __init__(self,
         self.to_k = nn.Linear(dim, dim, bias=False)
         self.to_v = nn.Linear(dim, dim, bias=False)
         self.to_out = nn.Linear(dim, dim, bias=False)
+        self.dropout = nn.Dropout(0.0)  # Match Diffusers dropout
 
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
@@ -215,15 +258,36 @@ def forward(self,
         query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
         key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
         value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+        print(f"[FASTVIDEO ATTN] After reshape - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ATTN] After reshape - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}\n")
 
         # Apply normalization
+        print(f"[FASTVIDEO ATTN] norm_q is not None: {self.norm_q is not None}, norm_k is not None: {self.norm_k is not None}")
+        print(f"[FASTVIDEO ATTN] norm_q type: {type(self.norm_q)}, norm_k type: {type(self.norm_k)}")
+        print(f"[FASTVIDEO ATTN] norm_q eps: {getattr(self.norm_q, 'variance_epsilon', 'N/A')}, norm_k eps: {getattr(self.norm_k, 'variance_epsilon', 'N/A')}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ATTN] norm_q is not None: {self.norm_q is not None}, norm_k is not None: {self.norm_k is not None}\n")
+            f.write(f"[FASTVIDEO ATTN] norm_q type: {type(self.norm_q)}, norm_k type: {type(self.norm_k)}\n")
+            f.write(f"[FASTVIDEO ATTN] norm_q eps: {getattr(self.norm_q, 'variance_epsilon', 'N/A')}, norm_k eps: {getattr(self.norm_k, 'variance_epsilon', 'N/A')}\n")
         if self.norm_q is not None:
             query = self.norm_q(query)
         if self.norm_k is not None:
             key = self.norm_k(key)
+        print(f"[FASTVIDEO ATTN] After norm - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ATTN] After norm - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}\n")
 
         # Apply RoPE if provided
         if image_rotary_emb is not None:
+            print(f"[FASTVIDEO ATTN] RoPE input shape: query={query.shape}, image_rotary_emb={len(image_rotary_emb) if isinstance(image_rotary_emb, tuple) else image_rotary_emb.shape}")
+            print(f"[FASTVIDEO ATTN] RoPE freqs shapes: cos={image_rotary_emb[0].shape}, sin={image_rotary_emb[1].shape}")
+            print(f"[FASTVIDEO ATTN] RoPE freqs sums: cos={image_rotary_emb[0].float().sum().item()}, sin={image_rotary_emb[1].float().sum().item()}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO ATTN] RoPE input shape: query={query.shape}, image_rotary_emb={len(image_rotary_emb) if isinstance(image_rotary_emb, tuple) else image_rotary_emb.shape}\n")
+                f.write(f"[FASTVIDEO ATTN] RoPE freqs shapes: cos={image_rotary_emb[0].shape}, sin={image_rotary_emb[1].shape}\n")
+                f.write(f"[FASTVIDEO ATTN] RoPE freqs sums: cos={image_rotary_emb[0].float().sum().item()}, sin={image_rotary_emb[1].float().sum().item()}\n")
+
             query = apply_rotary_emb(query,
                                      image_rotary_emb,
                                      use_real=True,
@@ -232,6 +296,9 @@ def forward(self,
                                    image_rotary_emb,
                                    use_real=True,
                                    use_real_unbind_dim=-2)
+            print(f"[FASTVIDEO ATTN] After RoPE - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO ATTN] After RoPE - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}\n")
 
         # Prepare for GQA (Grouped Query Attention)
         if torch.onnx.is_in_onnx_export():
@@ -244,6 +311,11 @@ def forward(self,
             value_idx = value.size(3)
             key = key.repeat_interleave(query_idx // key_idx, dim=3)
             value = value.repeat_interleave(query_idx // value_idx, dim=3)
+        print(f"[FASTVIDEO ATTN] After GQA - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}")
+        print(f"[FASTVIDEO ATTN] GQA indices - query_idx: {query_idx}, key_idx: {key_idx}, value_idx: {value_idx}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ATTN] After GQA - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO ATTN] GQA indices - query_idx: {query_idx}, key_idx: {key_idx}, value_idx: {value_idx}\n")
 
         # Attention computation
         # Use standard PyTorch scaled dot product attention
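
Note (not part of the diff): the repeat_interleave calls expand the key/value heads so grouped-query attention can reuse PyTorch's standard scaled-dot-product kernel. A standalone sketch of that expansion with illustrative sizes and a [batch, heads, seq, head_dim] layout (which differs from the axis used in this file):

import torch
import torch.nn.functional as F

B, S, D = 2, 16, 64                  # batch, sequence length, per-head dim
num_q_heads, num_kv_heads = 8, 2

query = torch.randn(B, num_q_heads, S, D)
key = torch.randn(B, num_kv_heads, S, D)
value = torch.randn(B, num_kv_heads, S, D)

# Each KV head serves num_q_heads // num_kv_heads query heads, so repeat it
# along the head dimension until the head counts match.
repeat = num_q_heads // num_kv_heads
key = key.repeat_interleave(repeat, dim=1)
value = value.repeat_interleave(repeat, dim=1)

out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)  # torch.Size([2, 8, 16, 64])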
@@ -258,6 +330,7 @@ def forward(self,
 
         # Output projection
         attn_output = self.to_out(attn_output)
+        attn_output = self.dropout(attn_output)
 
         return attn_output
 
@@ -285,6 +358,7 @@ def __init__(self,
         self.to_k = nn.Linear(cross_attention_dim, dim, bias=False)
         self.to_v = nn.Linear(cross_attention_dim, dim, bias=False)
         self.to_out = nn.Linear(dim, dim, bias=False)
+        self.dropout = nn.Dropout(0.0)  # Match Diffusers dropout
 
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
@@ -336,6 +410,7 @@ def forward(self,
 
         # Output projection
         attn_output = self.to_out(attn_output)
+        attn_output = self.dropout(attn_output)
 
         return attn_output
 
@@ -368,6 +443,7 @@ def __init__(
             dim=hidden_size,
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
+            eps=1e-5,  # Match Diffusers default
             prefix=f"{prefix}.attn1")
 
         self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size,
@@ -377,6 +453,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
+            eps=1e-5,  # Match Diffusers default
             prefix=f"{prefix}.attn2")
 
         self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size,
@@ -697,27 +774,14 @@ def forward(self,
         if condition_mask is not None:
             hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
 
-        if self.concat_padding_mask and padding_mask is not None:
+        if self.concat_padding_mask:
             from torchvision import transforms
             padding_mask = transforms.functional.resize(
-                padding_mask,
-                list(hidden_states.shape[-2:]),
-                interpolation=transforms.InterpolationMode.NEAREST)
-            hidden_states = torch.cat([
-                hidden_states,
-                padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1,
-                                                 1)
-            ],
-                                      dim=1)
-            # # Resize padding mask to match hidden states spatial dimensions
-            # padding_mask_resized = F.interpolate(
-            #     padding_mask.float().unsqueeze(1),
-            #     size=(height, width),
-            #     mode='nearest'
-            # ).squeeze(1)
-            # hidden_states = torch.cat(
-            #     [hidden_states, padding_mask_resized.unsqueeze(1).unsqueeze(2).repeat(1, 1, num_frames, 1, 1)], dim=1
-            # )
+                padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
+            )
+            hidden_states = torch.cat(
+                [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
+            )
 
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(
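
Note (not part of the diff): the rewritten branch resizes the padding mask to the latent's spatial size with nearest-neighbor interpolation and then concatenates it as an extra channel across all frames. A shape-level sketch of that flow, with purely illustrative sizes and torchvision assumed available:

import torch
from torchvision import transforms

batch_size, channels, num_frames, height, width = 1, 16, 8, 44, 80
hidden_states = torch.randn(batch_size, channels, num_frames, height, width)
padding_mask = torch.zeros(1, 1, 704, 1280)   # [1, 1, H_in, W_in] mask in pixel space

# Resize the mask to the latent spatial size with nearest-neighbor interpolation.
padding_mask = transforms.functional.resize(
    padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
)

# Add a frame axis, repeat over frames, then concatenate as one extra channel.
mask = padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)
hidden_states = torch.cat([hidden_states, mask], dim=1)
print(hidden_states.shape)  # torch.Size([1, 17, 8, 44, 80])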
Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Utility functions for pipeline stages.
+"""
+
+import inspect
+from typing import List, Optional, Union
+
+import torch
+
+
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+            must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
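
Note (not part of the diff): retrieve_timesteps dispatches on which of num_inference_steps, timesteps, or sigmas is supplied, so a toy scheduler is enough to exercise the three branches. ToyScheduler below is purely illustrative (not a real diffusers class), and retrieve_timesteps from the new file above is assumed to be in scope:

import torch

class ToyScheduler:
    """Minimal stand-in for a diffusers-style scheduler, for illustration only."""

    def set_timesteps(self, num_inference_steps=None, device=None,
                      timesteps=None, sigmas=None):
        if timesteps is not None:
            self.timesteps = torch.tensor(timesteps, device=device)
        elif sigmas is not None:
            # Map each sigma to a fake integer timestep just for the demo.
            self.timesteps = torch.arange(len(sigmas), device=device)
        else:
            self.timesteps = torch.linspace(999, 0, num_inference_steps, device=device)

scheduler = ToyScheduler()

ts, n = retrieve_timesteps(scheduler, num_inference_steps=4)
print(ts, n)    # default schedule, n == 4

ts, n = retrieve_timesteps(scheduler, timesteps=[900, 600, 300, 0])
print(ts, n)    # custom timesteps, n inferred from the list

ts, n = retrieve_timesteps(scheduler, sigmas=[1.0, 0.5, 0.0])
print(ts, n)    # custom sigmas, n == 3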
