 # limitations under the License.

 import math
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
@@ -102,6 +102,7 @@ def __init__(
         patch_size: int = 1,
         patch_size_t: int = 1,
         theta: float = 10000.0,
+        _causal_rope_fix: bool = False,
     ) -> None:
         super().__init__()

@@ -112,13 +113,15 @@ def __init__(
         self.patch_size = patch_size
         self.patch_size_t = patch_size_t
         self.theta = theta
+        self._causal_rope_fix = _causal_rope_fix

     def forward(
         self,
         hidden_states: torch.Tensor,
         num_frames: int,
         height: int,
         width: int,
+        frame_rate: Optional[int] = None,
         rope_interpolation_scale: Optional[Tuple[torch.Tensor, float, float]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size = hidden_states.size(0)
@@ -132,9 +135,24 @@ def forward(
         grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

         if rope_interpolation_scale is not None:
-            grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
-            grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
-            grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+            if isinstance(rope_interpolation_scale, tuple):
+                # This will be deprecated in v0.34.0
+                grid[:, 0:1] = grid[:, 0:1] * rope_interpolation_scale[0] * self.patch_size_t / self.base_num_frames
+                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1] * self.patch_size / self.base_height
+                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2] * self.patch_size / self.base_width
+            else:
+                if not self._causal_rope_fix:
+                    grid[:, 0:1] = (
+                        grid[:, 0:1] * rope_interpolation_scale[0:1] * self.patch_size_t / self.base_num_frames
+                    )
+                else:
+                    grid[:, 0:1] = (
+                        ((grid[:, 0:1] - 1) * rope_interpolation_scale[0:1] + 1 / frame_rate).clamp(min=0)
+                        * self.patch_size_t
+                        / self.base_num_frames
+                    )
+                grid[:, 1:2] = grid[:, 1:2] * rope_interpolation_scale[1:2] * self.patch_size / self.base_height
+                grid[:, 2:3] = grid[:, 2:3] * rope_interpolation_scale[2:3] * self.patch_size / self.base_width

         grid = grid.flatten(2, 4).transpose(1, 2)

@@ -315,6 +333,7 @@ def __init__(
         caption_channels: int = 4096,
         attention_bias: bool = True,
         attention_out_bias: bool = True,
+        _causal_rope_fix: bool = False,
     ) -> None:
         super().__init__()

@@ -336,6 +355,7 @@ def __init__(
             patch_size=patch_size,
             patch_size_t=patch_size_t,
             theta=10000.0,
+            _causal_rope_fix=_causal_rope_fix,
         )

         self.transformer_blocks = nn.ModuleList(
@@ -370,7 +390,8 @@ def forward(
         num_frames: int,
         height: int,
         width: int,
-        rope_interpolation_scale: Optional[Tuple[float, float, float]] = None,
+        frame_rate: int,
+        rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ) -> torch.Tensor:
@@ -389,7 +410,11 @@ def forward(
                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                 )

-        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, rope_interpolation_scale)
+        if not isinstance(rope_interpolation_scale, torch.Tensor):
+            msg = "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0."
+            deprecate("rope_interpolation_scale", "0.34.0", msg)
+
+        image_rotary_emb = self.rope(hidden_states, num_frames, height, width, frame_rate, rope_interpolation_scale)

         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
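
The heart of the change is the temporal RoPE coordinate in the tensor code path. Below is a minimal, self-contained sketch (not the library code) of how the old and the causal-fixed temporal positions differ; the constant names mirror the attributes used in the diff, but all values here are illustrative assumptions.

import torch

# Illustrative values only; real values come from the model config and pipeline.
patch_size_t = 1
base_num_frames = 20
frame_rate = 25
scale_t = torch.tensor(8 / 25)  # assumed temporal interpolation scale

t = torch.arange(5, dtype=torch.float32)  # latent frame indices 0, 1, 2, ...

# Pre-fix path: the frame index is scaled directly, so frame 0 always maps to 0.
pos_old = t * scale_t * patch_size_t / base_num_frames

# Causal-fix path: shift the index by one latent frame, add a 1 / frame_rate
# offset, and clamp at zero before normalizing.
pos_new = ((t - 1) * scale_t + 1 / frame_rate).clamp(min=0) * patch_size_t / base_num_frames

print(pos_old.tolist())
print(pos_new.tolist())

Note that the fixed behavior is opt-in: `_causal_rope_fix` defaults to False and only affects the tensor branch, so checkpoints registered without the flag keep their original positional embeddings, while the legacy tuple branch is preserved but scheduled for removal in v0.34.0.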
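
The call-site contract also changes: `frame_rate` is now an explicit argument of the transformer's `forward`, and `rope_interpolation_scale` accepts either the legacy 3-tuple or a tensor, with the non-tensor path emitting a deprecation warning. The standalone sketch below emulates that dispatch using `warnings` in place of diffusers' `deprecate` helper; the function name is hypothetical and only illustrates the structure of the new check.

import warnings
from typing import Optional, Tuple, Union

import torch


def resolve_rope_scale(
    rope_interpolation_scale: Optional[Union[Tuple[float, float, float], torch.Tensor]],
) -> Optional[Union[Tuple[float, float, float], torch.Tensor]]:
    # Mirrors the new check: anything that is not already a tensor
    # (including the legacy tuple) goes through the deprecation branch.
    if not isinstance(rope_interpolation_scale, torch.Tensor):
        warnings.warn(
            "Passing a tuple for `rope_interpolation_scale` is deprecated and will be "
            "removed in v0.34.0.",
            FutureWarning,
        )
    return rope_interpolation_scale


resolve_rope_scale((0.32, 32.0, 32.0))                # warns: legacy tuple path
resolve_rope_scale(torch.tensor([0.32, 32.0, 32.0]))  # no warning: new tensor path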