@@ -160,24 +160,26 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
160
160
super ().__init__ ()
161
161
self .theta = theta
162
162
self .axes_dim = axes_dim
163
- pos_index = torch .arange (1024 )
164
- neg_index = torch .arange (1024 ).flip (0 ) * - 1 - 1
165
- self .pos_freqs = torch .cat (
163
+ # Initialize with default size 1024, but allow dynamic expansion
164
+ self ._current_max_len = 1024
165
+ pos_index = torch .arange (self ._current_max_len )
166
+ neg_index = torch .arange (self ._current_max_len ).flip (0 ) * - 1 - 1
167
+ self .register_buffer ('pos_freqs' , torch .cat (
166
168
[
167
169
self .rope_params (pos_index , self .axes_dim [0 ], self .theta ),
168
170
self .rope_params (pos_index , self .axes_dim [1 ], self .theta ),
169
171
self .rope_params (pos_index , self .axes_dim [2 ], self .theta ),
170
172
],
171
173
dim = 1 ,
172
- )
173
- self .neg_freqs = torch .cat (
174
+ ))
175
+ self .register_buffer ( ' neg_freqs' , torch .cat (
174
176
[
175
177
self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
176
178
self .rope_params (neg_index , self .axes_dim [1 ], self .theta ),
177
179
self .rope_params (neg_index , self .axes_dim [2 ], self .theta ),
178
180
],
179
181
dim = 1 ,
180
- )
182
+ ))
181
183
self .rope_cache = {}
182
184
183
185
# 是否使用 scale rope
@@ -193,6 +195,45 @@ def rope_params(self, index, dim, theta=10000):
193
195
freqs = torch .polar (torch .ones_like (freqs ), freqs )
194
196
return freqs
195
197
198
+ def _expand_pos_freqs_if_needed (self , required_len ):
199
+ """Expand pos_freqs and neg_freqs if required length exceeds current size"""
200
+ if required_len <= self ._current_max_len :
201
+ return
202
+
203
+ # Calculate new size (use next power of 2 or round to nearest 512 for efficiency)
204
+ new_max_len = max (required_len , int ((required_len + 511 ) // 512 ) * 512 )
205
+
206
+ # Generate expanded indices
207
+ pos_index = torch .arange (new_max_len , device = self .pos_freqs .device )
208
+ neg_index = torch .arange (new_max_len , device = self .neg_freqs .device ).flip (0 ) * - 1 - 1
209
+
210
+ # Generate expanded frequency embeddings
211
+ new_pos_freqs = torch .cat (
212
+ [
213
+ self .rope_params (pos_index , self .axes_dim [0 ], self .theta ),
214
+ self .rope_params (pos_index , self .axes_dim [1 ], self .theta ),
215
+ self .rope_params (pos_index , self .axes_dim [2 ], self .theta ),
216
+ ],
217
+ dim = 1 ,
218
+ ).to (device = self .pos_freqs .device , dtype = self .pos_freqs .dtype )
219
+
220
+ new_neg_freqs = torch .cat (
221
+ [
222
+ self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
223
+ self .rope_params (neg_index , self .axes_dim [1 ], self .theta ),
224
+ self .rope_params (neg_index , self .axes_dim [2 ], self .theta ),
225
+ ],
226
+ dim = 1 ,
227
+ ).to (device = self .neg_freqs .device , dtype = self .neg_freqs .dtype )
228
+
229
+ # Update buffers
230
+ self .register_buffer ('pos_freqs' , new_pos_freqs )
231
+ self .register_buffer ('neg_freqs' , new_neg_freqs )
232
+ self ._current_max_len = new_max_len
233
+
234
+ # Clear cache since dimensions changed
235
+ self .rope_cache = {}
236
+
196
237
def forward (self , video_fhw , txt_seq_lens , device ):
197
238
"""
198
239
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
@@ -232,6 +273,11 @@ def forward(self, video_fhw, txt_seq_lens, device):
232
273
max_vid_index = max (height , width )
233
274
234
275
max_len = max (txt_seq_lens )
276
+
277
+ # Expand pos_freqs if needed to accommodate max_vid_index + max_len
278
+ required_len = max_vid_index + max_len
279
+ self ._expand_pos_freqs_if_needed (required_len )
280
+
235
281
txt_freqs = self .pos_freqs [max_vid_index : max_vid_index + max_len , ...]
236
282
237
283
return vid_freqs , txt_freqs
0 commit comments