@@ -164,22 +164,28 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
164
164
self ._current_max_len = 1024
165
165
pos_index = torch .arange (self ._current_max_len )
166
166
neg_index = torch .arange (self ._current_max_len ).flip (0 ) * - 1 - 1
167
- self .register_buffer ('pos_freqs' , torch .cat (
168
- [
169
- self .rope_params (pos_index , self .axes_dim [0 ], self .theta ),
170
- self .rope_params (pos_index , self .axes_dim [1 ], self .theta ),
171
- self .rope_params (pos_index , self .axes_dim [2 ], self .theta ),
172
- ],
173
- dim = 1 ,
174
- ))
175
- self .register_buffer ('neg_freqs' , torch .cat (
176
- [
177
- self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
178
- self .rope_params (neg_index , self .axes_dim [1 ], self .theta ),
179
- self .rope_params (neg_index , self .axes_dim [2 ], self .theta ),
180
- ],
181
- dim = 1 ,
182
- ))
167
+ self .register_buffer (
168
+ "pos_freqs" ,
169
+ torch .cat (
170
+ [
171
+ self .rope_params (pos_index , self .axes_dim [0 ], self .theta ),
172
+ self .rope_params (pos_index , self .axes_dim [1 ], self .theta ),
173
+ self .rope_params (pos_index , self .axes_dim [2 ], self .theta ),
174
+ ],
175
+ dim = 1 ,
176
+ ),
177
+ )
178
+ self .register_buffer (
179
+ "neg_freqs" ,
180
+ torch .cat (
181
+ [
182
+ self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
183
+ self .rope_params (neg_index , self .axes_dim [1 ], self .theta ),
184
+ self .rope_params (neg_index , self .axes_dim [2 ], self .theta ),
185
+ ],
186
+ dim = 1 ,
187
+ ),
188
+ )
183
189
self .rope_cache = {}
184
190
185
191
# Whether to use scale rope
@@ -199,22 +205,22 @@ def _expand_pos_freqs_if_needed(self, required_len):
199
205
"""Expand pos_freqs and neg_freqs if required length exceeds current size"""
200
206
if required_len <= self ._current_max_len :
201
207
return
202
-
208
+
203
209
# Calculate new size (round up to the next multiple of 512 for efficiency)
204
210
new_max_len = max (required_len , int ((required_len + 511 ) // 512 ) * 512 )
205
-
211
+
206
212
# Log warning about potential quality degradation for long prompts
207
213
if required_len > 512 :
208
214
logger .warning (
209
215
f"QwenImage model was trained on prompts up to 512 tokens. "
210
216
f"Current prompt requires { required_len } tokens, which may lead to unpredictable behavior. "
211
217
f"Consider using shorter prompts for better results."
212
218
)
213
-
219
+
214
220
# Generate expanded indices
215
221
pos_index = torch .arange (new_max_len , device = self .pos_freqs .device )
216
222
neg_index = torch .arange (new_max_len , device = self .neg_freqs .device ).flip (0 ) * - 1 - 1
217
-
223
+
218
224
# Generate expanded frequency embeddings
219
225
new_pos_freqs = torch .cat (
220
226
[
@@ -224,7 +230,7 @@ def _expand_pos_freqs_if_needed(self, required_len):
224
230
],
225
231
dim = 1 ,
226
232
).to (device = self .pos_freqs .device , dtype = self .pos_freqs .dtype )
227
-
233
+
228
234
new_neg_freqs = torch .cat (
229
235
[
230
236
self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
@@ -233,12 +239,12 @@ def _expand_pos_freqs_if_needed(self, required_len):
233
239
],
234
240
dim = 1 ,
235
241
).to (device = self .neg_freqs .device , dtype = self .neg_freqs .dtype )
236
-
242
+
237
243
# Update buffers
238
- self .register_buffer (' pos_freqs' , new_pos_freqs )
239
- self .register_buffer (' neg_freqs' , new_neg_freqs )
244
+ self .register_buffer (" pos_freqs" , new_pos_freqs )
245
+ self .register_buffer (" neg_freqs" , new_neg_freqs )
240
246
self ._current_max_len = new_max_len
241
-
247
+
242
248
# Clear cache since dimensions changed
243
249
self .rope_cache = {}
244
250
@@ -281,11 +287,11 @@ def forward(self, video_fhw, txt_seq_lens, device):
281
287
max_vid_index = max (height , width )
282
288
283
289
max_len = max (txt_seq_lens )
284
-
290
+
285
291
# Expand pos_freqs if needed to accommodate max_vid_index + max_len
286
292
required_len = max_vid_index + max_len
287
293
self ._expand_pos_freqs_if_needed (required_len )
288
-
294
+
289
295
txt_freqs = self .pos_freqs [max_vid_index : max_vid_index + max_len , ...]
290
296
291
297
return vid_freqs , txt_freqs
0 commit comments