@@ -115,46 +115,63 @@ def __init__(
115115        self .theta  =  theta 
116116        self ._causal_rope_fix  =  _causal_rope_fix 
117117
118-     def  forward (
119-         self ,
120-         hidden_states : torch .Tensor ,
121-         num_frames : int ,
122-         height : int ,
123-         width : int ,
124-         frame_rate : Optional [int ] =  None ,
125-         rope_interpolation_scale : Optional [Tuple [torch .Tensor , float , float ]] =  None ,
126-     ) ->  Tuple [torch .Tensor , torch .Tensor ]:
127-         batch_size  =  hidden_states .size (0 )
128- 
118+     
119+     def  _prepare_video_coords (self , batch_size : int , num_frames : int , height : int , width : int , rope_interpolation_scale : Tuple [torch .Tensor , float , float ], device : torch .device ) ->  torch .Tensor :
129120        # Always compute rope in fp32 
130-         grid_h  =  torch .arange (height , dtype = torch .float32 , device = hidden_states . device )
131-         grid_w  =  torch .arange (width , dtype = torch .float32 , device = hidden_states . device )
132-         grid_f  =  torch .arange (num_frames , dtype = torch .float32 , device = hidden_states . device )
121+         grid_h  =  torch .arange (height , dtype = torch .float32 , device = device )
122+         grid_w  =  torch .arange (width , dtype = torch .float32 , device = device )
123+         grid_f  =  torch .arange (num_frames , dtype = torch .float32 , device = device )
133124        grid  =  torch .meshgrid (grid_f , grid_h , grid_w , indexing = "ij" )
134125        grid  =  torch .stack (grid , dim = 0 )
135126        grid  =  grid .unsqueeze (0 ).repeat (batch_size , 1 , 1 , 1 , 1 )
136127
137-         if  rope_interpolation_scale  is  not None :
138-             if  isinstance (rope_interpolation_scale , tuple ):
139-                 # This will be deprecated in v0.34.0 
140-                 grid [:, 0 :1 ] =  grid [:, 0 :1 ] *  rope_interpolation_scale [0 ] *  self .patch_size_t  /  self .base_num_frames 
141-                 grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 ] *  self .patch_size  /  self .base_height 
142-                 grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 ] *  self .patch_size  /  self .base_width 
128+         if  isinstance (rope_interpolation_scale , tuple ):
129+             # This will be deprecated in v0.34.0 
130+             grid [:, 0 :1 ] =  grid [:, 0 :1 ] *  rope_interpolation_scale [0 ] *  self .patch_size_t  /  self .base_num_frames 
131+             grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 ] *  self .patch_size  /  self .base_height 
132+             grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 ] *  self .patch_size  /  self .base_width 
133+         else :
134+             if  not  self ._causal_rope_fix :
135+                 grid [:, 0 :1 ] =  (
136+                     grid [:, 0 :1 ] *  rope_interpolation_scale [0 :1 ] *  self .patch_size_t  /  self .base_num_frames 
137+                 )
143138            else :
144-                 if  not  self ._causal_rope_fix :
145-                     grid [:, 0 :1 ] =  (
146-                         grid [:, 0 :1 ] *  rope_interpolation_scale [0 :1 ] *  self .patch_size_t  /  self .base_num_frames 
147-                     )
148-                 else :
149-                     grid [:, 0 :1 ] =  (
150-                         ((grid [:, 0 :1 ] -  1 ) *  rope_interpolation_scale [0 :1 ] +  1  /  frame_rate ).clamp (min = 0 )
151-                         *  self .patch_size_t 
152-                         /  self .base_num_frames 
153-                     )
154-                 grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 :2 ] *  self .patch_size  /  self .base_height 
155-                 grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 :3 ] *  self .patch_size  /  self .base_width 
139+                 grid [:, 0 :1 ] =  (
140+                     ((grid [:, 0 :1 ] -  1 ) *  rope_interpolation_scale [0 :1 ] +  1  /  frame_rate ).clamp (min = 0 )
141+                     *  self .patch_size_t 
142+                     /  self .base_num_frames 
143+                 )
144+             grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 :2 ] *  self .patch_size  /  self .base_height 
145+             grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 :3 ] *  self .patch_size  /  self .base_width 
156146
157147        grid  =  grid .flatten (2 , 4 ).transpose (1 , 2 )
148+         
149+         return  grid 
150+     
151+ 
152+     def  forward (
153+         self ,
154+         hidden_states : torch .Tensor ,
155+         num_frames : Optional [int ] =  None ,
156+         height : Optional [int ] =  None ,
157+         width : Optional [int ] =  None ,
158+         frame_rate : Optional [int ] =  None ,
159+         rope_interpolation_scale : Optional [Tuple [torch .Tensor , float , float ]] =  None ,
160+         video_coords : Optional [torch .Tensor ] =  None ,
161+     ) ->  Tuple [torch .Tensor , torch .Tensor ]:
162+         batch_size  =  hidden_states .size (0 )
163+ 
164+         if  video_coords  is  None :
165+             grid  =  self ._prepare_video_coords (batch_size , num_frames , height , width , rope_interpolation_scale = rope_interpolation_scale , device = hidden_states .device )
166+         else :
167+             grid  =  torch .stack (
168+                 [
169+                     video_coords [:, 0 ] /  self .base_num_frames , 
170+                     video_coords [:, 1 ] /  self .base_height , 
171+                     video_coords [:, 2 ] /  self .base_width 
172+                 ], 
173+                 dim = - 1 ,
174+             )
158175
159176        start  =  1.0 
160177        end  =  self .theta 
@@ -387,11 +404,12 @@ def forward(
387404        encoder_hidden_states : torch .Tensor ,
388405        timestep : torch .LongTensor ,
389406        encoder_attention_mask : torch .Tensor ,
390-         num_frames : int ,
391-         height : int ,
392-         width : int ,
393-         frame_rate : int ,
407+         num_frames : Optional [ int ]  =   None ,
408+         height : Optional [ int ]  =   None ,
409+         width : Optional [ int ]  =   None ,
410+         frame_rate : Optional [ int ]  =   None ,
394411        rope_interpolation_scale : Optional [Union [Tuple [float , float , float ], torch .Tensor ]] =  None ,
412+         video_coords : Optional [torch .Tensor ] =  None ,
395413        attention_kwargs : Optional [Dict [str , Any ]] =  None ,
396414        return_dict : bool  =  True ,
397415    ) ->  torch .Tensor :
@@ -414,7 +432,8 @@ def forward(
414432            msg  =  "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0." 
415433            deprecate ("rope_interpolation_scale" , "0.34.0" , msg )
416434
417-         image_rotary_emb  =  self .rope (hidden_states , num_frames , height , width , frame_rate , rope_interpolation_scale )
435+ 
436+         image_rotary_emb  =  self .rope (hidden_states , num_frames , height , width , frame_rate , rope_interpolation_scale , video_coords )
418437
419438        # convert encoder_attention_mask to a bias the same way we do for attention_mask 
420439        if  encoder_attention_mask  is  not None  and  encoder_attention_mask .ndim  ==  2 :
@@ -475,5 +494,6 @@ def apply_rotary_emb(x, freqs):
475494    cos , sin  =  freqs 
476495    x_real , x_imag  =  x .unflatten (2 , (- 1 , 2 )).unbind (- 1 )  # [B, S, H, D // 2] 
477496    x_rotated  =  torch .stack ([- x_imag , x_real ], dim = - 1 ).flatten (2 )
478-     out  =  (x .float () *  cos  +  x_rotated .float () *  sin ).to (x .dtype )
497+     # YiYi TODO: testing only, remove this change before merging 
498+     out  =  (x  *  cos .to (x .dtype ) +  x_rotated  *  sin .to (x .dtype )).to (x .dtype )
479499    return  out 
0 commit comments