@@ -115,47 +115,77 @@ def __init__(
115115        self .theta  =  theta 
116116        self ._causal_rope_fix  =  _causal_rope_fix 
117117
118-     def  forward (
118+     def  _prepare_video_coords (
119119        self ,
120-         hidden_states :  torch . Tensor ,
120+         batch_size :  int ,
121121        num_frames : int ,
122122        height : int ,
123123        width : int ,
124-         frame_rate : Optional [int ] =  None ,
125-         rope_interpolation_scale : Optional [Tuple [torch .Tensor , float , float ]] =  None ,
126-     ) ->  Tuple [torch .Tensor , torch .Tensor ]:
127-         batch_size  =  hidden_states .size (0 )
128- 
124+         rope_interpolation_scale : Tuple [torch .Tensor , float , float ],
125+         frame_rate : float ,
126+         device : torch .device ,
127+     ) ->  torch .Tensor :
129128        # Always compute rope in fp32 
130-         grid_h  =  torch .arange (height , dtype = torch .float32 , device = hidden_states . device )
131-         grid_w  =  torch .arange (width , dtype = torch .float32 , device = hidden_states . device )
132-         grid_f  =  torch .arange (num_frames , dtype = torch .float32 , device = hidden_states . device )
129+         grid_h  =  torch .arange (height , dtype = torch .float32 , device = device )
130+         grid_w  =  torch .arange (width , dtype = torch .float32 , device = device )
131+         grid_f  =  torch .arange (num_frames , dtype = torch .float32 , device = device )
133132        grid  =  torch .meshgrid (grid_f , grid_h , grid_w , indexing = "ij" )
134133        grid  =  torch .stack (grid , dim = 0 )
135134        grid  =  grid .unsqueeze (0 ).repeat (batch_size , 1 , 1 , 1 , 1 )
136135
137-         if  rope_interpolation_scale  is  not None :
138-             if  isinstance (rope_interpolation_scale , tuple ):
139-                 # This will be deprecated in v0.34.0 
140-                 grid [:, 0 :1 ] =  grid [:, 0 :1 ] *  rope_interpolation_scale [0 ] *  self .patch_size_t  /  self .base_num_frames 
141-                 grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 ] *  self .patch_size  /  self .base_height 
142-                 grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 ] *  self .patch_size  /  self .base_width 
136+         if  isinstance (rope_interpolation_scale , tuple ):
137+             # This will be deprecated in v0.34.0 
138+             grid [:, 0 :1 ] =  grid [:, 0 :1 ] *  rope_interpolation_scale [0 ] *  self .patch_size_t  /  self .base_num_frames 
139+             grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 ] *  self .patch_size  /  self .base_height 
140+             grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 ] *  self .patch_size  /  self .base_width 
141+         else :
142+             if  not  self ._causal_rope_fix :
143+                 grid [:, 0 :1 ] =  grid [:, 0 :1 ] *  rope_interpolation_scale [0 :1 ] *  self .patch_size_t  /  self .base_num_frames 
143144            else :
144-                 if  not  self ._causal_rope_fix :
145-                     grid [:, 0 :1 ] =  (
146-                         grid [:, 0 :1 ] *  rope_interpolation_scale [0 :1 ] *  self .patch_size_t  /  self .base_num_frames 
147-                     )
148-                 else :
149-                     grid [:, 0 :1 ] =  (
150-                         ((grid [:, 0 :1 ] -  1 ) *  rope_interpolation_scale [0 :1 ] +  1  /  frame_rate ).clamp (min = 0 )
151-                         *  self .patch_size_t 
152-                         /  self .base_num_frames 
153-                     )
154-                 grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 :2 ] *  self .patch_size  /  self .base_height 
155-                 grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 :3 ] *  self .patch_size  /  self .base_width 
145+                 grid [:, 0 :1 ] =  (
146+                     ((grid [:, 0 :1 ] -  1 ) *  rope_interpolation_scale [0 :1 ] +  1  /  frame_rate ).clamp (min = 0 )
147+                     *  self .patch_size_t 
148+                     /  self .base_num_frames 
149+                 )
150+             grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 :2 ] *  self .patch_size  /  self .base_height 
151+             grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 :3 ] *  self .patch_size  /  self .base_width 
156152
157153        grid  =  grid .flatten (2 , 4 ).transpose (1 , 2 )
158154
155+         return  grid 
156+ 
157+     def  forward (
158+         self ,
159+         hidden_states : torch .Tensor ,
160+         num_frames : Optional [int ] =  None ,
161+         height : Optional [int ] =  None ,
162+         width : Optional [int ] =  None ,
163+         frame_rate : Optional [int ] =  None ,
164+         rope_interpolation_scale : Optional [Tuple [torch .Tensor , float , float ]] =  None ,
165+         video_coords : Optional [torch .Tensor ] =  None ,
166+     ) ->  Tuple [torch .Tensor , torch .Tensor ]:
167+         batch_size  =  hidden_states .size (0 )
168+ 
169+         if  video_coords  is  None :
170+             grid  =  self ._prepare_video_coords (
171+                 batch_size ,
172+                 num_frames ,
173+                 height ,
174+                 width ,
175+                 rope_interpolation_scale = rope_interpolation_scale ,
176+                 frame_rate = frame_rate ,
177+                 device = hidden_states .device ,
178+             )
179+         else :
180+             grid  =  torch .stack (
181+                 [
182+                     video_coords [:, 0 ] /  self .base_num_frames ,
183+                     video_coords [:, 1 ] /  self .base_height ,
184+                     video_coords [:, 2 ] /  self .base_width ,
185+                 ],
186+                 dim = - 1 ,
187+             )
188+ 
159189        start  =  1.0 
160190        end  =  self .theta 
161191        freqs  =  self .theta  **  torch .linspace (
@@ -387,11 +417,12 @@ def forward(
387417        encoder_hidden_states : torch .Tensor ,
388418        timestep : torch .LongTensor ,
389419        encoder_attention_mask : torch .Tensor ,
390-         num_frames : int ,
391-         height : int ,
392-         width : int ,
393-         frame_rate : int ,
420+         num_frames : Optional [ int ]  =   None ,
421+         height : Optional [ int ]  =   None ,
422+         width : Optional [ int ]  =   None ,
423+         frame_rate : Optional [ int ]  =   None ,
394424        rope_interpolation_scale : Optional [Union [Tuple [float , float , float ], torch .Tensor ]] =  None ,
425+         video_coords : Optional [torch .Tensor ] =  None ,
395426        attention_kwargs : Optional [Dict [str , Any ]] =  None ,
396427        return_dict : bool  =  True ,
397428    ) ->  torch .Tensor :
@@ -414,7 +445,9 @@ def forward(
414445            msg  =  "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0." 
415446            deprecate ("rope_interpolation_scale" , "0.34.0" , msg )
416447
417-         image_rotary_emb  =  self .rope (hidden_states , num_frames , height , width , frame_rate , rope_interpolation_scale )
448+         image_rotary_emb  =  self .rope (
449+             hidden_states , num_frames , height , width , frame_rate , rope_interpolation_scale , video_coords 
450+         )
418451
419452        # convert encoder_attention_mask to a bias the same way we do for attention_mask 
420453        if  encoder_attention_mask  is  not None  and  encoder_attention_mask .ndim  ==  2 :
0 commit comments