13
13
# limitations under the License.
14
14
15
15
16
+ import functools
16
17
import math
17
18
from typing import Any , Dict , List , Optional , Tuple , Union
18
19
@@ -162,15 +163,15 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
162
163
self .axes_dim = axes_dim
163
164
pos_index = torch .arange (1024 )
164
165
neg_index = torch .arange (1024 ).flip (0 ) * - 1 - 1
165
- self . pos_freqs = torch .cat (
166
+ pos_freqs = torch .cat (
166
167
[
167
168
self .rope_params (pos_index , self .axes_dim [0 ], self .theta ),
168
169
self .rope_params (pos_index , self .axes_dim [1 ], self .theta ),
169
170
self .rope_params (pos_index , self .axes_dim [2 ], self .theta ),
170
171
],
171
172
dim = 1 ,
172
173
)
173
- self . neg_freqs = torch .cat (
174
+ neg_freqs = torch .cat (
174
175
[
175
176
self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
176
177
self .rope_params (neg_index , self .axes_dim [1 ], self .theta ),
@@ -179,6 +180,8 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
179
180
dim = 1 ,
180
181
)
181
182
self .rope_cache = {}
183
+ self .register_buffer ("pos_freqs" , pos_freqs , persistent = False )
184
+ self .register_buffer ("neg_freqs" , neg_freqs , persistent = False )
182
185
183
186
# 是否使用 scale rope
184
187
self .scale_rope = scale_rope
@@ -198,33 +201,17 @@ def forward(self, video_fhw, txt_seq_lens, device):
198
201
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
199
202
txt_length: [bs] a list of 1 integers representing the length of the text
200
203
"""
201
- if self .pos_freqs .device != device :
202
- self .pos_freqs = self .pos_freqs .to (device )
203
- self .neg_freqs = self .neg_freqs .to (device )
204
-
205
204
if isinstance (video_fhw , list ):
206
205
video_fhw = video_fhw [0 ]
207
206
frame , height , width = video_fhw
208
207
rope_key = f"{ frame } _{ height } _{ width } "
209
208
210
- if rope_key not in self .rope_cache :
211
- seq_lens = frame * height * width
212
- freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
213
- freqs_neg = self .neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
214
- freqs_frame = freqs_pos [0 ][:frame ].view (frame , 1 , 1 , - 1 ).expand (frame , height , width , - 1 )
215
- if self .scale_rope :
216
- freqs_height = torch .cat ([freqs_neg [1 ][- (height - height // 2 ) :], freqs_pos [1 ][: height // 2 ]], dim = 0 )
217
- freqs_height = freqs_height .view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
218
- freqs_width = torch .cat ([freqs_neg [2 ][- (width - width // 2 ) :], freqs_pos [2 ][: width // 2 ]], dim = 0 )
219
- freqs_width = freqs_width .view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
220
-
221
- else :
222
- freqs_height = freqs_pos [1 ][:height ].view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
223
- freqs_width = freqs_pos [2 ][:width ].view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
224
-
225
- freqs = torch .cat ([freqs_frame , freqs_height , freqs_width ], dim = - 1 ).reshape (seq_lens , - 1 )
226
- self .rope_cache [rope_key ] = freqs .clone ().contiguous ()
227
- vid_freqs = self .rope_cache [rope_key ]
209
+ if not torch .compiler .is_compiling ():
210
+ if rope_key not in self .rope_cache :
211
+ self .rope_cache [rope_key ] = self ._compute_video_freqs (frame , height , width )
212
+ vid_freqs = self .rope_cache [rope_key ]
213
+ else :
214
+ vid_freqs = self ._compute_video_freqs (frame , height , width )
228
215
229
216
if self .scale_rope :
230
217
max_vid_index = max (height // 2 , width // 2 )
@@ -236,6 +223,25 @@ def forward(self, video_fhw, txt_seq_lens, device):
236
223
237
224
return vid_freqs , txt_freqs
238
225
226
+ @functools .lru_cache (maxsize = None )
227
+ def _compute_video_freqs (self , frame , height , width ):
228
+ seq_lens = frame * height * width
229
+ freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
230
+ freqs_neg = self .neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
231
+
232
+ freqs_frame = freqs_pos [0 ][:frame ].view (frame , 1 , 1 , - 1 ).expand (frame , height , width , - 1 )
233
+ if self .scale_rope :
234
+ freqs_height = torch .cat ([freqs_neg [1 ][- (height - height // 2 ) :], freqs_pos [1 ][: height // 2 ]], dim = 0 )
235
+ freqs_height = freqs_height .view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
236
+ freqs_width = torch .cat ([freqs_neg [2 ][- (width - width // 2 ) :], freqs_pos [2 ][: width // 2 ]], dim = 0 )
237
+ freqs_width = freqs_width .view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
238
+ else :
239
+ freqs_height = freqs_pos [1 ][:height ].view (1 , height , 1 , - 1 ).expand (frame , height , width , - 1 )
240
+ freqs_width = freqs_pos [2 ][:width ].view (1 , 1 , width , - 1 ).expand (frame , height , width , - 1 )
241
+
242
+ freqs = torch .cat ([freqs_frame , freqs_height , freqs_width ], dim = - 1 ).reshape (seq_lens , - 1 )
243
+ return freqs .clone ().contiguous ()
244
+
239
245
240
246
class QwenDoubleStreamAttnProcessor2_0 :
241
247
"""
@@ -482,6 +488,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
482
488
_supports_gradient_checkpointing = True
483
489
_no_split_modules = ["QwenImageTransformerBlock" ]
484
490
_skip_layerwise_casting_patterns = ["pos_embed" , "norm" ]
491
+ _repeated_blocks = ["QwenImageTransformerBlock" ]
485
492
486
493
@register_to_config
487
494
def __init__ (
0 commit comments