@@ -164,22 +164,28 @@ def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
164164        self ._current_max_len  =  1024 
165165        pos_index  =  torch .arange (self ._current_max_len )
166166        neg_index  =  torch .arange (self ._current_max_len ).flip (0 ) *  - 1  -  1 
167-         self .register_buffer ('pos_freqs' , torch .cat (
168-             [
169-                 self .rope_params (pos_index , self .axes_dim [0 ], self .theta ),
170-                 self .rope_params (pos_index , self .axes_dim [1 ], self .theta ),
171-                 self .rope_params (pos_index , self .axes_dim [2 ], self .theta ),
172-             ],
173-             dim = 1 ,
174-         ))
175-         self .register_buffer ('neg_freqs' , torch .cat (
176-             [
177-                 self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
178-                 self .rope_params (neg_index , self .axes_dim [1 ], self .theta ),
179-                 self .rope_params (neg_index , self .axes_dim [2 ], self .theta ),
180-             ],
181-             dim = 1 ,
182-         ))
167+         self .register_buffer (
168+             "pos_freqs" ,
169+             torch .cat (
170+                 [
171+                     self .rope_params (pos_index , self .axes_dim [0 ], self .theta ),
172+                     self .rope_params (pos_index , self .axes_dim [1 ], self .theta ),
173+                     self .rope_params (pos_index , self .axes_dim [2 ], self .theta ),
174+                 ],
175+                 dim = 1 ,
176+             ),
177+         )
178+         self .register_buffer (
179+             "neg_freqs" ,
180+             torch .cat (
181+                 [
182+                     self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
183+                     self .rope_params (neg_index , self .axes_dim [1 ], self .theta ),
184+                     self .rope_params (neg_index , self .axes_dim [2 ], self .theta ),
185+                 ],
186+                 dim = 1 ,
187+             ),
188+         )
183189        self .rope_cache  =  {}
184190
185191        # 是否使用 scale rope 
@@ -199,22 +205,22 @@ def _expand_pos_freqs_if_needed(self, required_len):
199205        """Expand pos_freqs and neg_freqs if required length exceeds current size""" 
200206        if  required_len  <=  self ._current_max_len :
201207            return 
202-          
208+ 
203209        # Calculate new size (use next power of 2 or round to nearest 512 for efficiency) 
204210        new_max_len  =  max (required_len , int ((required_len  +  511 ) //  512 ) *  512 )
205-          
211+ 
206212        # Log warning about potential quality degradation for long prompts 
207213        if  required_len  >  512 :
208214            logger .warning (
209215                f"QwenImage model was trained on prompts up to 512 tokens. " 
210216                f"Current prompt requires { required_len }   tokens, which may lead to unpredictable behavior. " 
211217                f"Consider using shorter prompts for better results." 
212218            )
213-          
219+ 
214220        # Generate expanded indices 
215221        pos_index  =  torch .arange (new_max_len , device = self .pos_freqs .device )
216222        neg_index  =  torch .arange (new_max_len , device = self .neg_freqs .device ).flip (0 ) *  - 1  -  1 
217-          
223+ 
218224        # Generate expanded frequency embeddings 
219225        new_pos_freqs  =  torch .cat (
220226            [
@@ -224,7 +230,7 @@ def _expand_pos_freqs_if_needed(self, required_len):
224230            ],
225231            dim = 1 ,
226232        ).to (device = self .pos_freqs .device , dtype = self .pos_freqs .dtype )
227-          
233+ 
228234        new_neg_freqs  =  torch .cat (
229235            [
230236                self .rope_params (neg_index , self .axes_dim [0 ], self .theta ),
@@ -233,12 +239,12 @@ def _expand_pos_freqs_if_needed(self, required_len):
233239            ],
234240            dim = 1 ,
235241        ).to (device = self .neg_freqs .device , dtype = self .neg_freqs .dtype )
236-          
242+ 
237243        # Update buffers 
238-         self .register_buffer (' pos_freqs'  , new_pos_freqs )
239-         self .register_buffer (' neg_freqs'  , new_neg_freqs )
244+         self .register_buffer (" pos_freqs"  , new_pos_freqs )
245+         self .register_buffer (" neg_freqs"  , new_neg_freqs )
240246        self ._current_max_len  =  new_max_len 
241-          
247+ 
242248        # Clear cache since dimensions changed 
243249        self .rope_cache  =  {}
244250
@@ -281,11 +287,11 @@ def forward(self, video_fhw, txt_seq_lens, device):
281287            max_vid_index  =  max (height , width )
282288
283289        max_len  =  max (txt_seq_lens )
284-          
290+ 
285291        # Expand pos_freqs if needed to accommodate max_vid_index + max_len 
286292        required_len  =  max_vid_index  +  max_len 
287293        self ._expand_pos_freqs_if_needed (required_len )
288-          
294+ 
289295        txt_freqs  =  self .pos_freqs [max_vid_index  : max_vid_index  +  max_len , ...]
290296
291297        return  vid_freqs , txt_freqs 
0 commit comments