@@ -545,6 +545,11 @@ def __init__(
545545 ) -> None :
546546 super ().__init__ ()
547547
548+ # Log RoPE parameters
549+ print (f"[FASTVIDEO ROPE INIT] hidden_size={ hidden_size } , max_size={ max_size } , patch_size={ patch_size } , base_fps={ base_fps } , rope_scale={ rope_scale } " )
550+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
551+ f .write (f"[FASTVIDEO ROPE INIT] hidden_size={ hidden_size } , max_size={ max_size } , patch_size={ patch_size } , base_fps={ base_fps } , rope_scale={ rope_scale } \n " )
552+
548553 self .max_size = [
549554 size // patch
550555 for size , patch in zip (max_size , patch_size , strict = False )
@@ -560,20 +565,39 @@ def __init__(
560565 self .w_ntk_factor = rope_scale [2 ]** (self .dim_w / (self .dim_w - 2 ))
561566 self .t_ntk_factor = rope_scale [0 ]** (self .dim_t / (self .dim_t - 2 ))
562567
568+ print (f"[FASTVIDEO ROPE INIT] dim_h={ self .dim_h } , dim_w={ self .dim_w } , dim_t={ self .dim_t } " )
569+ print (f"[FASTVIDEO ROPE INIT] h_ntk_factor={ self .h_ntk_factor } , w_ntk_factor={ self .w_ntk_factor } , t_ntk_factor={ self .t_ntk_factor } " )
570+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
571+ f .write (f"[FASTVIDEO ROPE INIT] dim_h={ self .dim_h } , dim_w={ self .dim_w } , dim_t={ self .dim_t } \n " )
572+ f .write (f"[FASTVIDEO ROPE INIT] h_ntk_factor={ self .h_ntk_factor } , w_ntk_factor={ self .w_ntk_factor } , t_ntk_factor={ self .t_ntk_factor } \n " )
573+
563574 def forward (self ,
564575 hidden_states : torch .Tensor ,
565576 fps : int | None = None ) -> tuple [torch .Tensor , torch .Tensor ]:
577+ fps = 16
566578 batch_size , num_channels , num_frames , height , width = hidden_states .shape
567579 pe_size = [
568580 num_frames // self .patch_size [0 ], height // self .patch_size [1 ],
569581 width // self .patch_size [2 ]
570582 ]
571583 device = hidden_states .device
572584
585+ print (f"[FASTVIDEO ROPE FORWARD] fps={ fps } , base_fps={ self .base_fps } " )
586+ print (f"[FASTVIDEO ROPE FORWARD] pe_size={ pe_size } , patch_size={ self .patch_size } " )
587+ print (f"[FASTVIDEO ROPE FORWARD] hidden_states.shape={ hidden_states .shape } " )
588+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
589+ f .write (f"[FASTVIDEO ROPE FORWARD] fps={ fps } , base_fps={ self .base_fps } \n " )
590+ f .write (f"[FASTVIDEO ROPE FORWARD] pe_size={ pe_size } , patch_size={ self .patch_size } \n " )
591+ f .write (f"[FASTVIDEO ROPE FORWARD] hidden_states.shape={ hidden_states .shape } \n " )
592+
573593 h_theta = 10000.0 * self .h_ntk_factor
574594 w_theta = 10000.0 * self .w_ntk_factor
575595 t_theta = 10000.0 * self .t_ntk_factor
576596
597+ print (f"[FASTVIDEO ROPE FORWARD] h_theta={ h_theta } , w_theta={ w_theta } , t_theta={ t_theta } " )
598+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
599+ f .write (f"[FASTVIDEO ROPE FORWARD] h_theta={ h_theta } , w_theta={ w_theta } , t_theta={ t_theta } \n " )
600+
577601 seq = torch .arange (max (self .max_size ),
578602 device = device ,
579603 dtype = torch .float32 )
@@ -586,10 +610,20 @@ def forward(self,
586610 dim_t_range = (
587611 torch .arange (0 , self .dim_t , 2 , device = device ,
588612 dtype = torch .float32 )[:(self .dim_t // 2 )] / self .dim_t )
613+ print (f"[FASTVIDEO ROPE FORWARD] max_size={ self .max_size } , seq.shape={ seq .shape } " )
614+ print (f"[FASTVIDEO ROPE FORWARD] dim_h_range.shape={ dim_h_range .shape } , dim_w_range.shape={ dim_w_range .shape } , dim_t_range.shape={ dim_t_range .shape } " )
615+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
616+ f .write (f"[FASTVIDEO ROPE FORWARD] max_size={ self .max_size } , seq.shape={ seq .shape } \n " )
617+ f .write (f"[FASTVIDEO ROPE FORWARD] dim_h_range.shape={ dim_h_range .shape } , dim_w_range.shape={ dim_w_range .shape } , dim_t_range.shape={ dim_t_range .shape } \n " )
618+
589619 h_spatial_freqs = 1.0 / (h_theta ** dim_h_range )
590620 w_spatial_freqs = 1.0 / (w_theta ** dim_w_range )
591621 temporal_freqs = 1.0 / (t_theta ** dim_t_range )
592622
623+ print (f"[FASTVIDEO ROPE FORWARD] h_spatial_freqs.shape={ h_spatial_freqs .shape } , w_spatial_freqs.shape={ w_spatial_freqs .shape } , temporal_freqs.shape={ temporal_freqs .shape } " )
624+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
625+ f .write (f"[FASTVIDEO ROPE FORWARD] h_spatial_freqs.shape={ h_spatial_freqs .shape } , w_spatial_freqs.shape={ w_spatial_freqs .shape } , temporal_freqs.shape={ temporal_freqs .shape } \n " )
626+
593627 emb_h = torch .outer (seq [:pe_size [1 ]],
594628 h_spatial_freqs )[None , :, None , :].repeat (
595629 pe_size [0 ], 1 , pe_size [2 ], 1 )
@@ -600,10 +634,16 @@ def forward(self,
600634 # Apply sequence scaling in temporal dimension
601635 if fps is None :
602636 # Images
637+ print (f"[FASTVIDEO ROPE FORWARD] Using image mode (fps=None)" )
603638 emb_t = torch .outer (seq [:pe_size [0 ]], temporal_freqs )
639+ with open ("/workspace/FastVideo/fastvideo_hidden_states.log" , "a" ) as f :
640+ f .write (f"[FASTVIDEO ROPE FORWARD] Using image mode (fps=None)\n " )
604641 else :
605642 # Videos
606- emb_t = torch .outer (seq [:pe_size [0 ]] / fps * self .base_fps ,
643+ print (f"[FASTVIDEO ROPE FORWARD] Using video mode (fps={ fps } )" )
644+ temporal_scale = seq [:pe_size [0 ]] / fps * self .base_fps
645+ print (f"[FASTVIDEO ROPE FORWARD] temporal_scale range: { temporal_scale .min ().item ():.6f} to { temporal_scale .max ().item ():.6f} " )
646+ emb_t = torch .outer (temporal_scale ,
607647 temporal_freqs )
608648
609649 emb_t = emb_t [:, None , None , :].repeat (1 , pe_size [1 ], pe_size [2 ], 1 )
0 commit comments