@@ -165,12 +165,7 @@ def compute_text_seq_len_from_mask(
165165 active_positions = torch .where (encoder_hidden_states_mask , position_ids , position_ids .new_zeros (()))
166166 has_active = encoder_hidden_states_mask .any (dim = 1 )
167167 per_sample_len = torch .where (has_active , active_positions .max (dim = 1 ).values + 1 , torch .as_tensor (text_seq_len ))
168-
169- # For RoPE, we use the full text_seq_len (since per_sample_len.max() <= text_seq_len always)
170- # Keep as tensor to avoid graph breaks in torch.compile
171- rope_text_seq_len = torch .tensor (text_seq_len , device = encoder_hidden_states .device , dtype = torch .long )
172-
173- return rope_text_seq_len , per_sample_len , encoder_hidden_states_mask
168+ return text_seq_len , per_sample_len , encoder_hidden_states_mask
174169
175170
176171class QwenTimestepProjEmbeddings (nn .Module ):
@@ -271,10 +266,6 @@ def forward(
271266 if max_txt_seq_len is None :
272267 raise ValueError ("Either `max_txt_seq_len` or `txt_seq_lens` (deprecated) must be provided." )
273268
274- # Move to device unconditionally to avoid graph breaks in torch.compile
275- self .pos_freqs = self .pos_freqs .to (device )
276- self .neg_freqs = self .neg_freqs .to (device )
277-
278269 # Validate batch inference with variable-sized images
279270 if isinstance (video_fhw , list ) and len (video_fhw ) > 1 :
280271 # Check if all instances have the same size
@@ -297,25 +288,29 @@ def forward(
297288 for idx , fhw in enumerate (video_fhw ):
298289 frame , height , width = fhw
299290 # RoPE frequencies are cached via a lru_cache decorator on _compute_video_freqs
300- video_freq = self ._compute_video_freqs (frame , height , width , idx )
301- video_freq = video_freq .to (device )
291+ video_freq = self ._compute_video_freqs (frame , height , width , idx , device )
302292 vid_freqs .append (video_freq )
303293
304294 if self .scale_rope :
305295 max_vid_index = max (height // 2 , width // 2 , max_vid_index )
306296 else :
307297 max_vid_index = max (height , width , max_vid_index )
308298
309- txt_freqs = self .pos_freqs [max_vid_index : max_vid_index + max_txt_seq_len , ...]
299+ max_txt_seq_len_int = int (max_txt_seq_len )
300+ # Create device-specific copy for text freqs without modifying self.pos_freqs
301+ txt_freqs = self .pos_freqs .to (device )[max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
310302 vid_freqs = torch .cat (vid_freqs , dim = 0 )
311303
312304 return vid_freqs , txt_freqs
313305
314306 @functools .lru_cache (maxsize = 128 )
315- def _compute_video_freqs (self , frame : int , height : int , width : int , idx : int = 0 ) -> torch .Tensor :
307+ def _compute_video_freqs (self , frame : int , height : int , width : int , idx : int = 0 , device : torch . device = None ) -> torch .Tensor :
316308 seq_lens = frame * height * width
317- freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
318- freqs_neg = self .neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
309+ pos_freqs = self .pos_freqs .to (device ) if device is not None else self .pos_freqs
310+ neg_freqs = self .neg_freqs .to (device ) if device is not None else self .neg_freqs
311+
312+ freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
313+ freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
319314
320315 freqs_frame = freqs_pos [0 ][idx : idx + frame ].view (frame , 1 , 1 , - 1 ).expand (frame , height , width , - 1 )
321316 if self .scale_rope :
@@ -384,10 +379,6 @@ def forward(
384379 device: (`torch.device`, *optional*):
385380 The device on which to perform the RoPE computation.
386381 """
387- # Move to device unconditionally to avoid graph breaks in torch.compile
388- self .pos_freqs = self .pos_freqs .to (device )
389- self .neg_freqs = self .neg_freqs .to (device )
390-
391382 # Validate batch inference with variable-sized images
392383 # In Layer3DRope, the outer list represents batch, inner list/tuple represents layers
393384 if isinstance (video_fhw , list ) and len (video_fhw ) > 1 :
@@ -412,11 +403,10 @@ def forward(
412403 for idx , fhw in enumerate (video_fhw ):
413404 frame , height , width = fhw
414405 if idx != layer_num :
415- video_freq = self ._compute_video_freqs (frame , height , width , idx )
406+ video_freq = self ._compute_video_freqs (frame , height , width , idx , device )
416407 else :
417408 ### For the condition image, we set the layer index to -1
418- video_freq = self ._compute_condition_freqs (frame , height , width )
419- video_freq = video_freq .to (device )
409+ video_freq = self ._compute_condition_freqs (frame , height , width , device )
420410 vid_freqs .append (video_freq )
421411
422412 if self .scale_rope :
@@ -425,16 +415,21 @@ def forward(
425415 max_vid_index = max (height , width , max_vid_index )
426416
427417 max_vid_index = max (max_vid_index , layer_num )
428- txt_freqs = self .pos_freqs [max_vid_index : max_vid_index + max_txt_seq_len , ...]
418+ max_txt_seq_len_int = int (max_txt_seq_len )
419+ # Create device-specific copy for text freqs without modifying self.pos_freqs
420+ txt_freqs = self .pos_freqs .to (device )[max_vid_index : max_vid_index + max_txt_seq_len_int , ...]
429421 vid_freqs = torch .cat (vid_freqs , dim = 0 )
430422
431423 return vid_freqs , txt_freqs
432424
433425 @functools .lru_cache (maxsize = None )
434- def _compute_video_freqs (self , frame , height , width , idx = 0 ):
426+ def _compute_video_freqs (self , frame , height , width , idx = 0 , device : torch . device = None ):
435427 seq_lens = frame * height * width
436- freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
437- freqs_neg = self .neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
428+ pos_freqs = self .pos_freqs .to (device ) if device is not None else self .pos_freqs
429+ neg_freqs = self .neg_freqs .to (device ) if device is not None else self .neg_freqs
430+
431+ freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
432+ freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
438433
439434 freqs_frame = freqs_pos [0 ][idx : idx + frame ].view (frame , 1 , 1 , - 1 ).expand (frame , height , width , - 1 )
440435 if self .scale_rope :
@@ -450,10 +445,13 @@ def _compute_video_freqs(self, frame, height, width, idx=0):
450445 return freqs .clone ().contiguous ()
451446
452447 @functools .lru_cache (maxsize = None )
453- def _compute_condition_freqs (self , frame , height , width ):
448+ def _compute_condition_freqs (self , frame , height , width , device : torch . device = None ):
454449 seq_lens = frame * height * width
455- freqs_pos = self .pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
456- freqs_neg = self .neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
450+ pos_freqs = self .pos_freqs .to (device ) if device is not None else self .pos_freqs
451+ neg_freqs = self .neg_freqs .to (device ) if device is not None else self .neg_freqs
452+
453+ freqs_pos = pos_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
454+ freqs_neg = neg_freqs .split ([x // 2 for x in self .axes_dim ], dim = 1 )
457455
458456 freqs_frame = freqs_neg [0 ][- 1 :].view (frame , 1 , 1 , - 1 ).expand (frame , height , width , - 1 )
459457 if self .scale_rope :
@@ -911,8 +909,8 @@ def forward(
911909 "txt_seq_lens" ,
912910 "0.37.0" ,
913911 "Passing `txt_seq_lens` is deprecated and will be removed in version 0.37.0. "
914- "Please use `txt_seq_len ` instead (singular, not plural) . "
915- "The new parameter accepts a single int or tensor value instead of a list ." ,
912+ "Please use `encoder_hidden_states_mask ` instead. "
913+ "The mask-based approach is more flexible and supports variable-length sequences ." ,
916914 standard_warn = False ,
917915 )
918916 if attention_kwargs is not None :
0 commit comments