@@ -113,7 +113,9 @@ def forward(self,
                                               self.embedding_dim]
 
         shift, scale = embedded_timestep.chunk(2, dim=-1)
-        hidden_states = self.norm(hidden_states)
+        # Disable autocast for LayerNorm to match Diffusers behavior
+        with torch.autocast(device_type="cuda", enabled=False):
+            hidden_states = self.norm(hidden_states)
 
         if embedded_timestep.ndim == 2:
             shift, scale = (x.unsqueeze(1) for x in (shift, scale))
@@ -147,6 +149,9 @@ def forward(
         embedded_timestep: torch.Tensor,
         temb: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        instance_id = id(self)
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO NORM] Instance {instance_id}: forward hidden_states: {hidden_states.float().sum().item()}\n")
         embedded_timestep = self.activation(embedded_timestep)
         embedded_timestep = self.linear_1(embedded_timestep)
         embedded_timestep = self.linear_2(embedded_timestep)
@@ -155,8 +160,45 @@ def forward(
         embedded_timestep = embedded_timestep + temb
 
         shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
-        hidden_states = self.norm(hidden_states)
-
+        print(f"[FASTVIDEO NORM] After chunk - shift sum: {shift.float().sum().item()}")
+        print(f"[FASTVIDEO NORM] After chunk - scale sum: {scale.float().sum().item()}")
+        print(f"[FASTVIDEO NORM] After chunk - gate sum: {gate.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO NORM] After chunk - shift sum: {shift.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO NORM] After chunk - scale sum: {scale.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO NORM] After chunk - gate sum: {gate.float().sum().item()}\n")
+        print(f"[FASTVIDEO NORM] Before LayerNorm - input shape: {hidden_states.shape}")
+        print(f"[FASTVIDEO NORM] Before LayerNorm - input dtype: {hidden_states.dtype}")
+        print(f"[FASTVIDEO NORM] Before LayerNorm - input sum: {hidden_states.float().sum().item()}")
+        print(f"[FASTVIDEO NORM] LayerNorm eps: {self.norm.eps}")
+        print(f"[FASTVIDEO NORM] LayerNorm elementwise_affine: {self.norm.elementwise_affine}")
+        print(f"[FASTVIDEO NORM] LayerNorm normalized_shape: {self.norm.normalized_shape}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO NORM] Before LayerNorm - input shape: {hidden_states.shape}\n")
+            f.write(f"[FASTVIDEO NORM] Before LayerNorm - input dtype: {hidden_states.dtype}\n")
+            f.write(f"[FASTVIDEO NORM] Before LayerNorm - input sum: {hidden_states.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO NORM] LayerNorm eps: {self.norm.eps}\n")
+            f.write(f"[FASTVIDEO NORM] LayerNorm elementwise_affine: {self.norm.elementwise_affine}\n")
+            f.write(f"[FASTVIDEO NORM] LayerNorm normalized_shape: {self.norm.normalized_shape}\n")
+
+        # Save the input tensor for comparison (only once globally)
+        import os
+        if not hasattr(CosmosAdaLayerNormZero, '_global_tensor_saved'):
+            instance_id = id(self)
+            torch.save(hidden_states.float(), "/workspace/FastVideo/fastvideo_layernorm_input.pt")
+            print(f"[FASTVIDEO NORM] Instance {instance_id}: Saved input tensor sum={hidden_states.float().sum().item()}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO NORM] Instance {instance_id}: Saved input tensor sum={hidden_states.float().sum().item()}\n")
+            CosmosAdaLayerNormZero._global_tensor_saved = True
+
+        # Disable autocast for LayerNorm to match Diffusers behavior
+        with torch.autocast(device_type="cuda", enabled=False):
+            hidden_states = self.norm(hidden_states)
+
+        print(f"[FASTVIDEO NORM] After LayerNorm - output sum: {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO NORM] After norm: {hidden_states.float().sum().item()}\n")
+            f.write(f"embedded_timestep.ndim: {embedded_timestep.ndim}\n")
         if embedded_timestep.ndim == 2:
             shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))
 
@@ -185,6 +227,7 @@ def __init__(self,
         self.to_k = nn.Linear(dim, dim, bias=False)
         self.to_v = nn.Linear(dim, dim, bias=False)
         self.to_out = nn.Linear(dim, dim, bias=False)
+        self.dropout = nn.Dropout(0.0)  # Match Diffusers dropout
 
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
@@ -215,15 +258,36 @@ def forward(self,
         query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
         key = key.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
         value = value.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
+        print(f"[FASTVIDEO ATTN] After reshape - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ATTN] After reshape - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}\n")
 
         # Apply normalization
+        print(f"[FASTVIDEO ATTN] norm_q is not None: {self.norm_q is not None}, norm_k is not None: {self.norm_k is not None}")
+        print(f"[FASTVIDEO ATTN] norm_q type: {type(self.norm_q)}, norm_k type: {type(self.norm_k)}")
+        print(f"[FASTVIDEO ATTN] norm_q eps: {getattr(self.norm_q, 'variance_epsilon', 'N/A')}, norm_k eps: {getattr(self.norm_k, 'variance_epsilon', 'N/A')}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ATTN] norm_q is not None: {self.norm_q is not None}, norm_k is not None: {self.norm_k is not None}\n")
+            f.write(f"[FASTVIDEO ATTN] norm_q type: {type(self.norm_q)}, norm_k type: {type(self.norm_k)}\n")
+            f.write(f"[FASTVIDEO ATTN] norm_q eps: {getattr(self.norm_q, 'variance_epsilon', 'N/A')}, norm_k eps: {getattr(self.norm_k, 'variance_epsilon', 'N/A')}\n")
         if self.norm_q is not None:
             query = self.norm_q(query)
         if self.norm_k is not None:
             key = self.norm_k(key)
+        print(f"[FASTVIDEO ATTN] After norm - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ATTN] After norm - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}\n")
 
         # Apply RoPE if provided
         if image_rotary_emb is not None:
+            print(f"[FASTVIDEO ATTN] RoPE input shape: query={query.shape}, image_rotary_emb={len(image_rotary_emb) if isinstance(image_rotary_emb, tuple) else image_rotary_emb.shape}")
+            print(f"[FASTVIDEO ATTN] RoPE freqs shapes: cos={image_rotary_emb[0].shape}, sin={image_rotary_emb[1].shape}")
+            print(f"[FASTVIDEO ATTN] RoPE freqs sums: cos={image_rotary_emb[0].float().sum().item()}, sin={image_rotary_emb[1].float().sum().item()}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO ATTN] RoPE input shape: query={query.shape}, image_rotary_emb={len(image_rotary_emb) if isinstance(image_rotary_emb, tuple) else image_rotary_emb.shape}\n")
+                f.write(f"[FASTVIDEO ATTN] RoPE freqs shapes: cos={image_rotary_emb[0].shape}, sin={image_rotary_emb[1].shape}\n")
+                f.write(f"[FASTVIDEO ATTN] RoPE freqs sums: cos={image_rotary_emb[0].float().sum().item()}, sin={image_rotary_emb[1].float().sum().item()}\n")
+
             query = apply_rotary_emb(query,
                                      image_rotary_emb,
                                      use_real=True,
@@ -232,6 +296,9 @@ def forward(self,
                                    image_rotary_emb,
                                    use_real=True,
                                    use_real_unbind_dim=-2)
+            print(f"[FASTVIDEO ATTN] After RoPE - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}")
+            with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                f.write(f"[FASTVIDEO ATTN] After RoPE - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}\n")
 
         # Prepare for GQA (Grouped Query Attention)
         if torch.onnx.is_in_onnx_export():
@@ -244,6 +311,11 @@ def forward(self,
             value_idx = value.size(3)
             key = key.repeat_interleave(query_idx // key_idx, dim=3)
             value = value.repeat_interleave(query_idx // value_idx, dim=3)
+        print(f"[FASTVIDEO ATTN] After GQA - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}")
+        print(f"[FASTVIDEO ATTN] GQA indices - query_idx: {query_idx}, key_idx: {key_idx}, value_idx: {value_idx}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO ATTN] After GQA - Q: {query.float().sum().item()}, K: {key.float().sum().item()}, V: {value.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO ATTN] GQA indices - query_idx: {query_idx}, key_idx: {key_idx}, value_idx: {value_idx}\n")
 
         # Attention computation
         # Use standard PyTorch scaled dot product attention
@@ -258,6 +330,7 @@ def forward(self,
 
         # Output projection
         attn_output = self.to_out(attn_output)
+        attn_output = self.dropout(attn_output)
 
         return attn_output
 
@@ -285,6 +358,7 @@ def __init__(self,
         self.to_k = nn.Linear(cross_attention_dim, dim, bias=False)
         self.to_v = nn.Linear(cross_attention_dim, dim, bias=False)
         self.to_out = nn.Linear(dim, dim, bias=False)
+        self.dropout = nn.Dropout(0.0)  # Match Diffusers dropout
 
         self.norm_q = RMSNorm(self.head_dim,
                               eps=eps) if qk_norm else nn.Identity()
@@ -336,6 +410,7 @@ def forward(self,
 
         # Output projection
         attn_output = self.to_out(attn_output)
+        attn_output = self.dropout(attn_output)
 
         return attn_output
 
@@ -368,6 +443,7 @@ def __init__(
             dim=hidden_size,
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
+            eps=1e-5,  # Match Diffusers default
             prefix=f"{prefix}.attn1")
 
         self.norm2 = CosmosAdaLayerNormZero(in_features=hidden_size,
@@ -377,6 +453,7 @@ def __init__(
             cross_attention_dim=cross_attention_dim,
             num_heads=num_attention_heads,
             qk_norm=(qk_norm == "rms_norm"),
+            eps=1e-5,  # Match Diffusers default
             prefix=f"{prefix}.attn2")
 
         self.norm3 = CosmosAdaLayerNormZero(in_features=hidden_size,
@@ -697,27 +774,14 @@ def forward(self,
         if condition_mask is not None:
             hidden_states = torch.cat([hidden_states, condition_mask], dim=1)
 
-        if self.concat_padding_mask and padding_mask is not None:
+        if self.concat_padding_mask:
             from torchvision import transforms
             padding_mask = transforms.functional.resize(
-                padding_mask,
-                list(hidden_states.shape[-2:]),
-                interpolation=transforms.InterpolationMode.NEAREST)
-            hidden_states = torch.cat([
-                hidden_states,
-                padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1,
-                                                 1)
-            ],
-                                      dim=1)
-            # # Resize padding mask to match hidden states spatial dimensions
-            # padding_mask_resized = F.interpolate(
-            #     padding_mask.float().unsqueeze(1),
-            #     size=(height, width),
-            #     mode='nearest'
-            # ).squeeze(1)
-            # hidden_states = torch.cat(
-            #     [hidden_states, padding_mask_resized.unsqueeze(1).unsqueeze(2).repeat(1, 1, num_frames, 1, 1)], dim=1
-            # )
+                padding_mask, list(hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
+            )
+            hidden_states = torch.cat(
+                [hidden_states, padding_mask.unsqueeze(2).repeat(batch_size, 1, num_frames, 1, 1)], dim=1
+            )
 
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(
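The recurring change in this diff is wrapping LayerNorm in torch.autocast(device_type="cuda", enabled=False) so its numerics match the Diffusers reference. The following standalone sketch (not part of the commit; it assumes a CUDA device, bfloat16 autocast, and illustrative names) shows how to check what that wrapper changes: the output dtype and values of nn.LayerNorm with and without autocast in effect.

# Standalone sketch, not from FastVideo or Diffusers: compares nn.LayerNorm
# output under autocast vs. with autocast explicitly disabled, which is the
# distinction the patch above is controlling for.
import torch
import torch.nn as nn

def compare_layernorm_under_autocast() -> None:
    # Illustrative shapes/dtype; the real model tensors differ.
    x = torch.randn(2, 16, 128, device="cuda", dtype=torch.bfloat16)
    norm = nn.LayerNorm(x.shape[-1], elementwise_affine=False).to(x.device)

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        y_autocast = norm(x)  # follows PyTorch's autocast policy for layer_norm
        with torch.autocast(device_type="cuda", enabled=False):
            y_plain = norm(x)  # runs in the tensor's own dtype, no autocast casting

    print("autocast output dtype:", y_autocast.dtype)
    print("autocast-disabled output dtype:", y_plain.dtype)
    print("max abs difference:",
          (y_autocast.float() - y_plain.float()).abs().max().item())

if __name__ == "__main__" and torch.cuda.is_available():
    compare_layernorm_under_autocast()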