1313# limitations under the License. 
1414
1515
16+ import  functools 
1617import  math 
1718from  typing  import  Any , Dict , List , Optional , Tuple , Union 
1819
@@ -162,15 +163,15 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
162163        self .axes_dim  =  axes_dim 
163164        pos_index  =  torch .arange (1024 )
164165        neg_index  =  torch .arange (1024 ).flip (0 ) *  - 1  -  1 
165-         self . pos_freqs  =  torch .cat (
166+         pos_freqs  =  torch .cat (
166167            [
167168                self .rope_params (pos_index , self .axes_dim [0 ], self .theta ),
168169                self .rope_params (pos_index , self .axes_dim [1 ], self .theta ),
169170                self .rope_params (pos_index , self .axes_dim [2 ], self .theta ),
170171            ],
171172            dim = 1 ,
172173        )
173-         self . neg_freqs  =  torch .cat (
174+         neg_freqs  =  torch .cat (
174175            [
175176                self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
176177                self .rope_params (neg_index , self .axes_dim [1 ], self .theta ),
@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
179180            dim = 1 ,
180181        )
181182        self .rope_cache  =  {}
183+         self .register_buffer ("pos_freqs" , pos_freqs , persistent = False )
184+         self .register_buffer ("neg_freqs" , neg_freqs , persistent = False )
182185
183186        # 是否使用 scale rope 
184187        self .scale_rope  =  scale_rope 
@@ -198,33 +201,17 @@ def forward(self, video_fhw, txt_seq_lens, device):
198201        Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: 
199202        txt_length: [bs] a list of 1 integers representing the length of the text 
200203        """ 
201-         if  self .pos_freqs .device  !=  device :
202-             self .pos_freqs  =  self .pos_freqs .to (device )
203-             self .neg_freqs  =  self .neg_freqs .to (device )
204- 
205204        if  isinstance (video_fhw , list ):
206205            video_fhw  =  video_fhw [0 ]
207206        frame , height , width  =  video_fhw 
208207        rope_key  =  f"{ frame } { height } { width }  
209208
210-         if  rope_key  not  in self .rope_cache :
211-             seq_lens  =  frame  *  height  *  width 
212-             freqs_pos  =  self .pos_freqs .split ([x  //  2  for  x  in  self .axes_dim ], dim = 1 )
213-             freqs_neg  =  self .neg_freqs .split ([x  //  2  for  x  in  self .axes_dim ], dim = 1 )
214-             freqs_frame  =  freqs_pos [0 ][:frame ].view (frame , 1 , 1 , - 1 ).expand (frame , height , width , - 1 )
215-             if  self .scale_rope :
216-                 freqs_height  =  torch .cat ([freqs_neg [1 ][- (height  -  height  //  2 ) :], freqs_pos [1 ][: height  //  2 ]], dim = 0 )
217-                 freqs_height  =  freqs_height .view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
218-                 freqs_width  =  torch .cat ([freqs_neg [2 ][- (width  -  width  //  2 ) :], freqs_pos [2 ][: width  //  2 ]], dim = 0 )
219-                 freqs_width  =  freqs_width .view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
220- 
221-             else :
222-                 freqs_height  =  freqs_pos [1 ][:height ].view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
223-                 freqs_width  =  freqs_pos [2 ][:width ].view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
224- 
225-             freqs  =  torch .cat ([freqs_frame , freqs_height , freqs_width ], dim = - 1 ).reshape (seq_lens , - 1 )
226-             self .rope_cache [rope_key ] =  freqs .clone ().contiguous ()
227-         vid_freqs  =  self .rope_cache [rope_key ]
209+         if  not  torch .compiler .is_compiling ():
210+             if  rope_key  not  in self .rope_cache :
211+                 self .rope_cache [rope_key ] =  self ._compute_video_freqs (frame , height , width )
212+             vid_freqs  =  self .rope_cache [rope_key ]
213+         else :
214+             vid_freqs  =  self ._compute_video_freqs (frame , height , width )
228215
229216        if  self .scale_rope :
230217            max_vid_index  =  max (height  //  2 , width  //  2 )
@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):
236223
237224        return  vid_freqs , txt_freqs 
238225
226+     @functools .lru_cache (maxsize = None ) 
227+     def  _compute_video_freqs (self , frame , height , width ):
228+         seq_lens  =  frame  *  height  *  width 
229+         freqs_pos  =  self .pos_freqs .split ([x  //  2  for  x  in  self .axes_dim ], dim = 1 )
230+         freqs_neg  =  self .neg_freqs .split ([x  //  2  for  x  in  self .axes_dim ], dim = 1 )
231+ 
232+         freqs_frame  =  freqs_pos [0 ][:frame ].view (frame , 1 , 1 , - 1 ).expand (frame , height , width , - 1 )
233+         if  self .scale_rope :
234+             freqs_height  =  torch .cat ([freqs_neg [1 ][- (height  -  height  //  2 ) :], freqs_pos [1 ][: height  //  2 ]], dim = 0 )
235+             freqs_height  =  freqs_height .view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
236+             freqs_width  =  torch .cat ([freqs_neg [2 ][- (width  -  width  //  2 ) :], freqs_pos [2 ][: width  //  2 ]], dim = 0 )
237+             freqs_width  =  freqs_width .view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
238+         else :
239+             freqs_height  =  freqs_pos [1 ][:height ].view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
240+             freqs_width  =  freqs_pos [2 ][:width ].view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
241+ 
242+         freqs  =  torch .cat ([freqs_frame , freqs_height , freqs_width ], dim = - 1 ).reshape (seq_lens , - 1 )
243+         return  freqs .clone ().contiguous ()
244+ 
239245
240246class  QwenDoubleStreamAttnProcessor2_0 :
241247    """ 
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
482488    _supports_gradient_checkpointing  =  True 
483489    _no_split_modules  =  ["QwenImageTransformerBlock" ]
484490    _skip_layerwise_casting_patterns  =  ["pos_embed" , "norm" ]
491+     _repeated_blocks  =  ["QwenImageTransformerBlock" ]
485492
486493    @register_to_config  
487494    def  __init__ (
0 commit comments