@@ -76,8 +76,14 @@ def __init__(self, embedding_dim: int, condition_dim: int) -> None:
     def forward(self, hidden_states: torch.Tensor,
                 timestep: torch.LongTensor) -> torch.Tensor:
         timesteps_proj = self.time_proj(timestep).type_as(hidden_states)
+        print(f"[FASTVIDEO] timesteps_proj before norm: {timesteps_proj.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] timesteps_proj before norm: {timesteps_proj.float().sum().item()}\n")
         temb = self.t_embedder(timesteps_proj)
         embedded_timestep = self.norm(timesteps_proj)
+        print(f"[FASTVIDEO] embedded_timestep after norm: {embedded_timestep.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] embedded_timestep after norm: {embedded_timestep.float().sum().item()}\n")
         return temb, embedded_timestep


@@ -133,10 +139,7 @@ def __init__(self,
         else:
             self.linear_1 = nn.Linear(in_features, hidden_features, bias=False)

-        self.linear_2 = nn.Linear(
-            hidden_features if hidden_features is not None else in_features,
-            3 * in_features,
-            bias=False)
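+        # hidden_features is presumably guaranteed non-None by this point, so the None fallback is dropped.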
+        self.linear_2 = nn.Linear(hidden_features, 3 * in_features, bias=False)

     def forward(
         self,
@@ -197,10 +200,16 @@ def forward(self,
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states

+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO SELF-ATTN] INIT hidden_states: Q={hidden_states.float().sum().item()}\n")
+
         # Get QKV
         query = self.to_q(hidden_states)
         key = self.to_k(encoder_hidden_states)
         value = self.to_v(encoder_hidden_states)
+        print(f"[FASTVIDEO SELF-ATTN] QKV sums: Q={query.float().sum().item()}, K={key.float().sum().item()}, V={value.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO SELF-ATTN] QKV sums: Q={query.float().sum().item()}, K={key.float().sum().item()}, V={value.float().sum().item()}\n")

         # Reshape for multi-head attention
         query = query.unflatten(2, (self.num_heads, -1)).transpose(1, 2)
@@ -209,9 +218,9 @@ def forward(self,

         # Apply normalization
         if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
         if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)

         # Apply RoPE if provided
         if image_rotary_emb is not None:
@@ -224,12 +233,28 @@ def forward(self,
                 use_real=True,
                 use_real_unbind_dim=-2)

+        # Prepare for GQA (Grouped Query Attention)
+        if torch.onnx.is_in_onnx_export():
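+            # Wrapping the sizes in tensors presumably keeps the head-size ratio
+            # as a traced op during ONNX export instead of a baked-in Python int.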
+            query_idx = torch.tensor(query.size(3), device=query.device)
+            key_idx = torch.tensor(key.size(3), device=key.device)
+            value_idx = torch.tensor(value.size(3), device=value.device)
+        else:
+            query_idx = query.size(3)
+            key_idx = key.size(3)
+            value_idx = value.size(3)
+        key = key.repeat_interleave(query_idx // key_idx, dim=3)
+        value = value.repeat_interleave(query_idx // value_idx, dim=3)
+
         # Attention computation
         # Use standard PyTorch scaled dot product attention
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
         attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)
+        print(f"[FASTVIDEO TRANSFORMER] hidden_states: {attn_output.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO TRANSFORMER] hidden_states: {attn_output.float().sum().item()}\n")
+            f.write(f"self.to_out: {self.to_out}\n")

         # Output projection
         attn_output = self.to_out(attn_output)
@@ -275,6 +300,9 @@ def forward(self,
         query = self.to_q(hidden_states)
         key = self.to_k(encoder_hidden_states)
         value = self.to_v(encoder_hidden_states)
+        # print(f"[FASTVIDEO CROSS-ATTN] QKV sums: Q={query.float().sum().item()}, K={key.float().sum().item()}, V={value.float().sum().item()}")
+        # with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        #     f.write(f"[FASTVIDEO CROSS-ATTN] QKV sums: Q={query.float().sum().item()}, K={key.float().sum().item()}, V={value.float().sum().item()}\n")

         # Reshape for multi-head attention
         # Standard PyTorch attention expects [batch, num_heads, seq_len, head_dim]
@@ -284,13 +312,25 @@ def forward(self,

         # Apply normalization
         if self.norm_q is not None:
-            query = self.norm_q.forward_native(query)
+            query = self.norm_q(query)
         if self.norm_k is not None:
-            key = self.norm_k.forward_native(key)
+            key = self.norm_k(key)
+
+        # Prepare for GQA (Grouped Query Attention)
+        if torch.onnx.is_in_onnx_export():
+            query_idx = torch.tensor(query.size(3), device=query.device)
+            key_idx = torch.tensor(key.size(3), device=key.device)
+            value_idx = torch.tensor(value.size(3), device=value.device)
+        else:
+            query_idx = query.size(3)
+            key_idx = key.size(3)
+            value_idx = value.size(3)
+        key = key.repeat_interleave(query_idx // key_idx, dim=3)
+        value = value.repeat_interleave(query_idx // value_idx, dim=3)

         # Attention computation
         attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
         attn_output = attn_output.transpose(1, 2).flatten(2, 3).type_as(query)

@@ -317,6 +357,11 @@ def __init__(

         hidden_size = num_attention_heads * attention_head_dim

+        print(f"[FASTVIDEO TRANSFORMER] hidden_size: Q={hidden_size}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO TRANSFORMER] hidden_size: Q={hidden_size}\n")
+
+
         self.norm1 = CosmosAdaLayerNormZero(in_features=hidden_size,
                                             hidden_features=adaln_lora_dim)
         self.attn1 = CosmosSelfAttention(
@@ -355,18 +400,51 @@ def forward(
         hidden_states = hidden_states + extra_pos_emb

         # 1. Self Attention
+        print(f"[FASTVIDEO DEBUG] Before norm1: hidden_states={hidden_states.float().sum().item()}")
+        print(f"[FASTVIDEO DEBUG] Before norm1: embedded_timestep={embedded_timestep.float().sum().item()}")
+        print(f"[FASTVIDEO DEBUG] Before norm1: temb={temb.float().sum().item() if temb is not None else 'None'}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO DEBUG] Before norm1: hidden_states={hidden_states.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO DEBUG] Before norm1: embedded_timestep={embedded_timestep.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO DEBUG] Before norm1: temb={temb.float().sum().item() if temb is not None else 'None'}\n")
+        # Debug norm1 weights
+        print(f"[FASTVIDEO DEBUG] norm1.linear_1.weight sum: {self.norm1.linear_1.weight.float().sum().item()}")
+        print(f"[FASTVIDEO DEBUG] norm1.linear_2.weight sum: {self.norm1.linear_2.weight.float().sum().item()}")
+        print(f"[FASTVIDEO DEBUG] hidden_states dtype: {hidden_states.dtype}")
+        print(f"[FASTVIDEO DEBUG] embedded_timestep dtype: {embedded_timestep.dtype}")
+        print(f"[FASTVIDEO DEBUG] temb dtype: {temb.dtype if temb is not None else 'None'}")
+        print(f"[FASTVIDEO DEBUG] norm1.linear_1.weight dtype: {self.norm1.linear_1.weight.dtype}")
+        print(f"[FASTVIDEO DEBUG] norm1.linear_2.weight dtype: {self.norm1.linear_2.weight.dtype}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO DEBUG] norm1.linear_1.weight sum: {self.norm1.linear_1.weight.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO DEBUG] norm1.linear_2.weight sum: {self.norm1.linear_2.weight.float().sum().item()}\n")
+            f.write(f"[FASTVIDEO DEBUG] hidden_states dtype: {hidden_states.dtype}\n")
+            f.write(f"[FASTVIDEO DEBUG] embedded_timestep dtype: {embedded_timestep.dtype}\n")
+            f.write(f"[FASTVIDEO DEBUG] temb dtype: {temb.dtype if temb is not None else 'None'}\n")
+            f.write(f"[FASTVIDEO DEBUG] norm1.linear_1.weight dtype: {self.norm1.linear_1.weight.dtype}\n")
+            f.write(f"[FASTVIDEO DEBUG] norm1.linear_2.weight dtype: {self.norm1.linear_2.weight.dtype}\n")
+
         norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep,
                                               temb)
+        print(f"[FASTVIDEO DEBUG] After norm1: norm_hidden_states={norm_hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO DEBUG] After norm1: norm_hidden_states={norm_hidden_states.float().sum().item()}\n")
         attn_output = self.attn1(norm_hidden_states,
                                  image_rotary_emb=image_rotary_emb)
         hidden_states = hidden_states + gate * attn_output

         # 2. Cross Attention
+        # print(f"[FASTVIDEO] About to call cross-attention")
+        # with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        #     f.write(f"[FASTVIDEO] About to call cross-attention\n")
         norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep,
                                               temb)
         attn_output = self.attn2(norm_hidden_states,
                                  encoder_hidden_states=encoder_hidden_states,
                                  attention_mask=attention_mask)
+        # print(f"[FASTVIDEO] Cross-attention completed")
+        # with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+        #     f.write(f"[FASTVIDEO] Cross-attention completed\n")
         hidden_states = hidden_states + gate * attn_output

         # 3. Feed Forward
@@ -604,6 +682,8 @@ def forward(self,
                 padding_mask: torch.Tensor | None = None,
                 **kwargs) -> torch.Tensor:
         print(f"[FASTVIDEO TRANSFORMER] Input hidden_states sum = {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO TRANSFORMER] Input hidden_states sum = {hidden_states.float().sum().item()}\n")
         forward_batch = get_forward_context().forward_batch
         enable_teacache = forward_batch is not None and forward_batch.enable_teacache

@@ -676,9 +756,19 @@ def forward(self,
         else:
             raise ValueError(f"Unsupported timestep shape: {timestep.shape}")

+        print(f"[FASTVIDEO] After patch_embed: {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After patch_embed: {hidden_states.float().sum().item()}\n")
+        print(f"[FASTVIDEO] After time_embed temb: {temb.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After time_embed temb: {temb.float().sum().item()}\n")
+        print(f"[FASTVIDEO] After time_embed embedded_timestep: {embedded_timestep.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After time_embed embedded_timestep: {embedded_timestep.float().sum().item()}\n")
+
         # 6. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for block in self.transformer_blocks:
+            for i, block in enumerate(self.transformer_blocks):
                 hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
@@ -689,8 +779,12 @@ def forward(self,
                     extra_pos_emb,
                     attention_mask,
                 )
+                if i < 3:  # Log first 3 blocks
+                    print(f"[FASTVIDEO] After block {i}: {hidden_states.float().sum().item()}")
+                    with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                        f.write(f"[FASTVIDEO] After block {i}: {hidden_states.float().sum().item()}\n")
         else:
-            for block in self.transformer_blocks:
+            for i, block in enumerate(self.transformer_blocks):
                 hidden_states = block(
                     hidden_states=hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -700,10 +794,20 @@ def forward(self,
                     extra_pos_emb=extra_pos_emb,
                     attention_mask=attention_mask,
                 )
+                if i < 3:  # Log first 3 blocks
+                    print(f"[FASTVIDEO] After block! {i}: {hidden_states.float().sum().item()}")
+                    with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+                        f.write(f"[FASTVIDEO] After block! {i}: {hidden_states.float().sum().item()}\n")

         # 7. Output norm & projection & unpatchify
         hidden_states = self.norm_out(hidden_states, embedded_timestep, temb)
+        print(f"[FASTVIDEO] After norm_out: {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After norm_out: {hidden_states.float().sum().item()}\n")
         hidden_states = self.proj_out(hidden_states)
+        print(f"[FASTVIDEO] After proj_out: {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO] After proj_out: {hidden_states.float().sum().item()}\n")
         hidden_states = hidden_states.unflatten(2, (p_h, p_w, p_t, -1))
         hidden_states = hidden_states.unflatten(
             1, (post_patch_num_frames, post_patch_height, post_patch_width))
@@ -713,4 +817,6 @@ def forward(self,
         hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

         print(f"[FASTVIDEO TRANSFORMER] Output hidden_states sum = {hidden_states.float().sum().item()}")
+        with open("/workspace/FastVideo/fastvideo_hidden_states.log", "a") as f:
+            f.write(f"[FASTVIDEO TRANSFORMER] Output hidden_states sum = {hidden_states.float().sum().item()}\n")
         return hidden_states