2222
2323from  ...configuration_utils  import  ConfigMixin , register_to_config 
2424from  ...loaders  import  FromOriginalModelMixin , PeftAdapterMixin 
25- from  ...utils  import  USE_PEFT_BACKEND , deprecate ,  logging , scale_lora_layers , unscale_lora_layers 
25+ from  ...utils  import  USE_PEFT_BACKEND , logging , scale_lora_layers , unscale_lora_layers 
2626from  ...utils .torch_utils  import  maybe_allow_in_graph 
2727from  ..attention  import  FeedForward 
2828from  ..attention_processor  import  Attention 
@@ -102,7 +102,6 @@ def __init__(
102102        patch_size : int  =  1 ,
103103        patch_size_t : int  =  1 ,
104104        theta : float  =  10000.0 ,
105-         _causal_rope_fix : bool  =  False ,
106105    ) ->  None :
107106        super ().__init__ ()
108107
@@ -113,7 +112,6 @@ def __init__(
113112        self .patch_size  =  patch_size 
114113        self .patch_size_t  =  patch_size_t 
115114        self .theta  =  theta 
116-         self ._causal_rope_fix  =  _causal_rope_fix 
117115
118116    def  _prepare_video_coords (
119117        self ,
@@ -133,22 +131,10 @@ def _prepare_video_coords(
133131        grid  =  torch .stack (grid , dim = 0 )
134132        grid  =  grid .unsqueeze (0 ).repeat (batch_size , 1 , 1 , 1 , 1 )
135133
136-         if  isinstance (rope_interpolation_scale , tuple ):
137-             # This will be deprecated in v0.34.0 
134+         if  rope_interpolation_scale  is  not None :
138135            grid [:, 0 :1 ] =  grid [:, 0 :1 ] *  rope_interpolation_scale [0 ] *  self .patch_size_t  /  self .base_num_frames 
139136            grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 ] *  self .patch_size  /  self .base_height 
140137            grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 ] *  self .patch_size  /  self .base_width 
141-         else :
142-             if  not  self ._causal_rope_fix :
143-                 grid [:, 0 :1 ] =  grid [:, 0 :1 ] *  rope_interpolation_scale [0 :1 ] *  self .patch_size_t  /  self .base_num_frames 
144-             else :
145-                 grid [:, 0 :1 ] =  (
146-                     ((grid [:, 0 :1 ] -  1 ) *  rope_interpolation_scale [0 :1 ] +  1  /  frame_rate ).clamp (min = 0 )
147-                     *  self .patch_size_t 
148-                     /  self .base_num_frames 
149-                 )
150-             grid [:, 1 :2 ] =  grid [:, 1 :2 ] *  rope_interpolation_scale [1 :2 ] *  self .patch_size  /  self .base_height 
151-             grid [:, 2 :3 ] =  grid [:, 2 :3 ] *  rope_interpolation_scale [2 :3 ] *  self .patch_size  /  self .base_width 
152138
153139        grid  =  grid .flatten (2 , 4 ).transpose (1 , 2 )
154140
@@ -363,7 +349,6 @@ def __init__(
363349        caption_channels : int  =  4096 ,
364350        attention_bias : bool  =  True ,
365351        attention_out_bias : bool  =  True ,
366-         _causal_rope_fix : bool  =  False ,
367352    ) ->  None :
368353        super ().__init__ ()
369354
@@ -385,7 +370,6 @@ def __init__(
385370            patch_size = patch_size ,
386371            patch_size_t = patch_size_t ,
387372            theta = 10000.0 ,
388-             _causal_rope_fix = _causal_rope_fix ,
389373        )
390374
391375        self .transformer_blocks  =  nn .ModuleList (
@@ -441,10 +425,6 @@ def forward(
441425                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." 
442426                )
443427
444-         if  not  isinstance (rope_interpolation_scale , torch .Tensor ):
445-             msg  =  "Passing a tuple for `rope_interpolation_scale` is deprecated and will be removed in v0.34.0." 
446-             deprecate ("rope_interpolation_scale" , "0.34.0" , msg )
447- 
448428        image_rotary_emb  =  self .rope (
449429            hidden_states , num_frames , height , width , frame_rate , rope_interpolation_scale , video_coords 
450430        )
0 commit comments