1313# limitations under the License.
1414
1515
16- import functools
1716import math
1817from typing import Any , Dict , List , Optional , Tuple , Union
1918
@@ -161,17 +160,17 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
161160 super ().__init__ ()
162161 self .theta = theta
163162 self .axes_dim = axes_dim
164- pos_index = torch .arange (1024 )
165- neg_index = torch .arange (1024 ).flip (0 ) * - 1 - 1
166- pos_freqs = torch .cat (
163+ pos_index = torch .arange (4096 )
164+ neg_index = torch .arange (4096 ).flip (0 ) * - 1 - 1
165+ self . pos_freqs = torch .cat (
167166 [
168167 self .rope_params (pos_index , self .axes_dim [0 ], self .theta ),
169168 self .rope_params (pos_index , self .axes_dim [1 ], self .theta ),
170169 self .rope_params (pos_index , self .axes_dim [2 ], self .theta ),
171170 ],
172171 dim = 1 ,
173172 )
174- neg_freqs = torch .cat (
173+ self . neg_freqs = torch .cat (
175174 [
176175 self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
177176 self .rope_params (neg_index , self .axes_dim [1 ], self .theta ),
@@ -180,10 +179,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
180179 dim = 1 ,
181180 )
182181 self .rope_cache = {}
183- self .register_buffer ("pos_freqs" , pos_freqs , persistent = False )
184- self .register_buffer ("neg_freqs" , neg_freqs , persistent = False )
185182
186- # 是否使用 scale rope
183+ # DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
187184 self .scale_rope = scale_rope
188185
189186 def rope_params (self , index , dim , theta = 10000 ):
@@ -201,47 +198,51 @@ def forward(self, video_fhw, txt_seq_lens, device):
201198 Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
202199 txt_length: [bs] a list of 1 integers representing the length of the text
203200 """
201+ if self .pos_freqs .device != device :
202+ self .pos_freqs = self .pos_freqs .to (device )
203+ self .neg_freqs = self .neg_freqs .to (device )
204+
204205 if isinstance (video_fhw , list ):
205206 video_fhw = video_fhw [0 ]
206- frame , height , width = video_fhw
207- rope_key = f"{ frame } _{ height } _{ width } "
208207
209- if not torch .compiler .is_compiling ():
210- if rope_key not in self .rope_cache :
211- self .rope_cache [rope_key ] = self ._compute_video_freqs (frame , height , width )
212- vid_freqs = self .rope_cache [rope_key ]
213- else :
214- vid_freqs = self ._compute_video_freqs (frame , height , width )
208+ vid_freqs = []
209+ max_vid_index = 0
210+ for idx , fhw in enumerate (video_fhw ):
211+ frame , height , width = fhw
212+ rope_key = f"{ idx } _{ height } _{ width } "
215213
216- if self .scale_rope :
217- max_vid_index = max (height // 2 , width // 2 )
218- else :
219- max_vid_index = max (height , width )
214+ if rope_key not in self .rope_cache :
215+ seq_lens = frame * height * width
216+ freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
217+ freqs_neg = self .neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
218+ freqs_frame = freqs_pos [0 ][idx : idx + frame ].view (frame , 1 , 1 , - 1 ).expand (frame , height , width , - 1 )
219+ if self .scale_rope :
220+ freqs_height = torch .cat (
221+ [freqs_neg [1 ][- (height - height // 2 ) :], freqs_pos [1 ][: height // 2 ]], dim = 0
222+ )
223+ freqs_height = freqs_height .view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
224+ freqs_width = torch .cat ([freqs_neg [2 ][- (width - width // 2 ) :], freqs_pos [2 ][: width // 2 ]], dim = 0 )
225+ freqs_width = freqs_width .view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
226+
227+ else :
228+ freqs_height = freqs_pos [1 ][:height ].view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
229+ freqs_width = freqs_pos [2 ][:width ].view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
230+
231+ freqs = torch .cat ([freqs_frame , freqs_height , freqs_width ], dim = - 1 ).reshape (seq_lens , - 1 )
232+ self .rope_cache [rope_key ] = freqs .clone ().contiguous ()
233+ vid_freqs .append (self .rope_cache [rope_key ])
234+
235+ if self .scale_rope :
236+ max_vid_index = max (height // 2 , width // 2 , max_vid_index )
237+ else :
238+ max_vid_index = max (height , width , max_vid_index )
220239
221240 max_len = max (txt_seq_lens )
222241 txt_freqs = self .pos_freqs [max_vid_index : max_vid_index + max_len , ...]
242+ vid_freqs = torch .cat (vid_freqs , dim = 0 )
223243
224244 return vid_freqs , txt_freqs
225245
226- @functools .lru_cache (maxsize = None )
227- def _compute_video_freqs (self , frame , height , width ):
228- seq_lens = frame * height * width
229- freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
230- freqs_neg = self .neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
231-
232- freqs_frame = freqs_pos [0 ][:frame ].view (frame , 1 , 1 , - 1 ).expand (frame , height , width , - 1 )
233- if self .scale_rope :
234- freqs_height = torch .cat ([freqs_neg [1 ][- (height - height // 2 ) :], freqs_pos [1 ][: height // 2 ]], dim = 0 )
235- freqs_height = freqs_height .view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
236- freqs_width = torch .cat ([freqs_neg [2 ][- (width - width // 2 ) :], freqs_pos [2 ][: width // 2 ]], dim = 0 )
237- freqs_width = freqs_width .view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
238- else :
239- freqs_height = freqs_pos [1 ][:height ].view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
240- freqs_width = freqs_pos [2 ][:width ].view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
241-
242- freqs = torch .cat ([freqs_frame , freqs_height , freqs_width ], dim = - 1 ).reshape (seq_lens , - 1 )
243- return freqs .clone ().contiguous ()
244-
245246
246247class QwenDoubleStreamAttnProcessor2_0 :
247248 """
0 commit comments