@@ -13,6 +13,7 @@
 # limitations under the License.
 
 
+import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -162,15 +163,15 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
         self.axes_dim = axes_dim
         pos_index = torch.arange(1024)
         neg_index = torch.arange(1024).flip(0) * -1 - 1
-        self.pos_freqs = torch.cat(
+        pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
                 self.rope_params(pos_index, self.axes_dim[2], self.theta),
             ],
             dim=1,
         )
-        self.neg_freqs = torch.cat(
+        neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
             dim=1,
         )
         self.rope_cache = {}
+        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
+        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
 
         # Whether to use scale RoPE
         self.scale_rope = scale_rope
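
Note on the buffer change above: registering `pos_freqs`/`neg_freqs` with `register_buffer(..., persistent=False)` makes them follow `model.to(device)` and dtype casts automatically (which is why the manual device check in `forward` is dropped below) while keeping them out of the `state_dict`, so existing checkpoints are unaffected. A minimal sketch of that behaviour, using a toy module invented for illustration:

```python
import torch
import torch.nn as nn


class ToyRope(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        freqs = torch.randn(1024, 64)
        # persistent=False: the buffer moves/casts with the module but is not serialized
        self.register_buffer("freqs", freqs, persistent=False)


m = ToyRope()
print("freqs" in m.state_dict())  # False -> checkpoints keep their old layout
m = m.to(torch.bfloat16)          # buffer follows dtype casts ...
if torch.cuda.is_available():
    m = m.to("cuda")              # ... and device moves, no manual .to() in forward needed
print(m.freqs.dtype, m.freqs.device)
```
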
@@ -198,33 +201,17 @@ def forward(self, video_fhw, txt_seq_lens, device):
         Args: video_fhw: [frame, height, width], a list of 3 integers representing the shape of the video
             txt_seq_lens: [bs], a list of integers representing the length of each text prompt
         """
-        if self.pos_freqs.device != device:
-            self.pos_freqs = self.pos_freqs.to(device)
-            self.neg_freqs = self.neg_freqs.to(device)
-
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
         frame, height, width = video_fhw
         rope_key = f"{frame}_{height}_{width}"
 
-        if rope_key not in self.rope_cache:
-            seq_lens = frame * height * width
-            freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
-            freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
-            if self.scale_rope:
-                freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
-                freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
-                freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            else:
-                freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
-                freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
-
-            freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
-            self.rope_cache[rope_key] = freqs.clone().contiguous()
-        vid_freqs = self.rope_cache[rope_key]
+        if not torch.compiler.is_compiling():
+            if rope_key not in self.rope_cache:
+                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
+            vid_freqs = self.rope_cache[rope_key]
+        else:
+            vid_freqs = self._compute_video_freqs(frame, height, width)
 
         if self.scale_rope:
             max_vid_index = max(height // 2, width // 2)
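
The eager-mode dict cache is kept, but it is bypassed when the module runs under `torch.compile`: mutating a Python dict inside a traced `forward` is a common source of graph breaks and recompilations, whereas recomputing the frequencies inside the compiled graph is cheap. A rough sketch of the same guard pattern (assuming a PyTorch version that provides `torch.compiler.is_compiling()`):

```python
import torch

_cache = {}  # eager-mode memo, keyed by "frame_height_width"


def cached_or_fresh(key, compute):
    # Eager: memoize the result so repeated shapes reuse the same tensor.
    # Compiling: skip the Python-side cache so the traced graph stays free of
    # dict lookups and mutations; the compiled artifact itself is reused across calls.
    if not torch.compiler.is_compiling():
        if key not in _cache:
            _cache[key] = compute()
        return _cache[key]
    return compute()
```
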
@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):
 
         return vid_freqs, txt_freqs
 
+    @functools.lru_cache(maxsize=None)
+    def _compute_video_freqs(self, frame, height, width):
+        seq_lens = frame * height * width
+        freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+        freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
+
+        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        if self.scale_rope:
+            freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
+            freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
+            freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
+        else:
+            freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
+            freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
+
+        freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
+        return freqs.clone().contiguous()
+
 
 class QwenDoubleStreamAttnProcessor2_0:
     """
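
`functools.lru_cache(maxsize=None)` on `_compute_video_freqs` memoizes the computed tensor per `(self, frame, height, width)` combination in eager execution (with the usual caveat that the cache holds a reference to the instance). A tiny illustration of how `lru_cache` behaves on a method, with names made up for the example:

```python
import functools


class Grid:  # illustrative stand-in, not part of the diff
    @functools.lru_cache(maxsize=None)
    def area(self, h, w):
        print("computing")  # runs once per distinct (self, h, w)
        return h * w


g = Grid()
g.area(4, 8)                   # prints "computing"
g.area(4, 8)                   # served from the cache, nothing printed
print(Grid.area.cache_info())  # CacheInfo(hits=1, misses=1, ...)
```
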
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
     _supports_gradient_checkpointing = True
     _no_split_modules = ["QwenImageTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+    _repeated_blocks = ["QwenImageTransformerBlock"]
 
     @register_to_config
     def __init__(
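
`_repeated_blocks` marks `QwenImageTransformerBlock` as the unit for regional compilation, i.e. compiling just the repeated transformer block rather than the whole model, which cuts compile time while keeping most of the speedup. A hedged usage sketch: `compile_repeated_blocks` is assumed to be available on recent diffusers `ModelMixin` versions, and the `Qwen/Qwen-Image` checkpoint id is an assumption here; the plain `torch.compile` fallback works on any PyTorch 2.x.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")

# Regional compilation (assumption: helper exposed by recent diffusers, driven by _repeated_blocks)
pipe.transformer.compile_repeated_blocks(fullgraph=True)

# Fallback: compile the whole transformer with stock torch.compile
# pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune")

image = pipe("a cup of coffee on a wooden table", num_inference_steps=30).images[0]
image.save("coffee.png")
```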