@@ -264,22 +264,21 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
         batch_size, channels, height, width = hidden_states.shape
         p = self.patch_size
         post_patch_height, post_patch_width = height // p, width // p
-        num_patches = post_patch_height * post_patch_width
+        image_seq_len = post_patch_height * post_patch_width
         device = hidden_states.device

-        # Get caption lengths and calculate max sequence length
+        encoder_seq_len = attention_mask.shape[1]
         l_effective_cap_len = attention_mask.sum(dim=1).tolist()
-        max_seq_len = max(l_effective_cap_len) + num_patches
+        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+        max_seq_len = max(seq_lengths)

         # Create position IDs
         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)

-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
             # Set caption positions
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
-            position_ids[i, cap_len : cap_len + num_patches, 0] = cap_len
+            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+            position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len

             # Set image patch positions
             row_ids = (
@@ -294,34 +293,33 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
                 .repeat(post_patch_height, 1)
                 .flatten()
             )
-            position_ids[i, cap_len : cap_len + num_patches, 1] = row_ids
-            position_ids[i, cap_len : cap_len + num_patches, 2] = col_ids
+            position_ids[i, cap_seq_len:seq_len, 1] = row_ids
+            position_ids[i, cap_seq_len:seq_len, 2] = col_ids

         # Get frequencies
         freqs_cis = self._get_freqs_cis(position_ids)

         # Split frequencies for captions and images
         cap_freqs_cis = torch.zeros(
-            batch_size, attention_mask.shape[1], freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+        )
+        img_freqs_cis = torch.zeros(
+            batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
         )
-        img_freqs_cis = torch.zeros(batch_size, num_patches, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype)

-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :num_patches] = freqs_cis[i, cap_len : cap_len + num_patches]
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]

         # patch embeddings
         hidden_states = (
-            hidden_states.view(
-                batch_size, channels, post_patch_height, self.patch_size, post_patch_width, self.patch_size
-            )
+            hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
             .permute(0, 2, 4, 3, 5, 1)
             .flatten(3)
             .flatten(1, 2)
         )

-        return hidden_states, freqs_cis, cap_freqs_cis, img_freqs_cis
+        return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths


 class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
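
Note (illustration only, not part of the diff): a minimal standalone sketch of the 3-axis position IDs the refactored rope embedder builds above, assuming a toy caption of length 2 and a 2x2 patch grid. repeat_interleave/repeat is used here as an equivalent shorthand for the view/repeat/flatten construction in the diff.

    import torch

    cap_seq_len, post_patch_height, post_patch_width = 2, 2, 2  # toy values
    image_seq_len = post_patch_height * post_patch_width
    seq_len = cap_seq_len + image_seq_len

    position_ids = torch.zeros(seq_len, 3, dtype=torch.int32)
    # axis 0: caption tokens count up 0..cap_seq_len-1, image tokens are pinned to cap_seq_len
    position_ids[:cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32)
    position_ids[cap_seq_len:, 0] = cap_seq_len
    # axes 1 and 2: row and column index of each image patch
    position_ids[cap_seq_len:, 1] = torch.arange(post_patch_height, dtype=torch.int32).repeat_interleave(post_patch_width)
    position_ids[cap_seq_len:, 2] = torch.arange(post_patch_width, dtype=torch.int32).repeat(post_patch_height)
    print(position_ids.tolist())
    # [[0, 0, 0], [1, 0, 0], [2, 0, 0], [2, 0, 1], [2, 1, 0], [2, 1, 1]]
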
@@ -468,22 +466,17 @@ def forward(
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         # 1. Condition, positional & patch embedding
         batch_size, _, height, width = hidden_states.shape
-        p = self.config.patch_size
-        post_patch_height, post_patch_width = height // p, width // p
-        num_patches = post_patch_height * post_patch_width
-
-        # effective_text_seq_lengths is based on actual caption length, so it's different for each prompt in a batch
-        effective_encoder_seq_lengths = encoder_attention_mask.sum(dim=1).tolist()
-        seq_lengths = [
-            encoder_seq_len + num_patches for encoder_seq_len in effective_encoder_seq_lengths
-        ]  # Add num_patches to each length
-        max_seq_len = max(seq_lengths)

         temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)

-        hidden_states, rotary_emb, context_rotary_emb, noise_rotary_emb = self.rope_embedder(
-            hidden_states, encoder_attention_mask
-        )
+        (
+            hidden_states,
+            context_rotary_emb,
+            noise_rotary_emb,
+            rotary_emb,
+            encoder_seq_lengths,
+            seq_lengths,
+        ) = self.rope_embedder(hidden_states, encoder_attention_mask)

         hidden_states = self.x_embedder(hidden_states)

@@ -497,12 +490,13 @@ def forward(
             hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)

         # 3. Joint Transformer blocks
+        max_seq_len = max(seq_lengths)
         attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
         joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
-        for i, (effective_encoder_seq_len, seq_len) in enumerate(zip(effective_encoder_seq_lengths, seq_lengths)):
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
             attention_mask[i, :seq_len] = True
-            joint_hidden_states[i, :effective_encoder_seq_len] = encoder_hidden_states[i, :effective_encoder_seq_len]
-            joint_hidden_states[i, effective_encoder_seq_len:seq_len] = hidden_states[i]
+            joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
+            joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]

         hidden_states = joint_hidden_states

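
Note (illustration only, not part of the diff): a standalone sketch of the packing loop above, which copies each sample's real caption tokens plus its image tokens into one padded joint sequence and masks the padding. Toy values throughout (two samples, caption lengths 3 and 5, 4 image tokens each, hidden_size 8).

    import torch

    hidden_size, image_seq_len = 8, 4    # toy values
    encoder_seq_lengths = [3, 5]         # effective caption length per sample
    seq_lengths = [n + image_seq_len for n in encoder_seq_lengths]
    max_seq_len = max(seq_lengths)
    batch_size = len(seq_lengths)

    encoder_hidden_states = torch.randn(batch_size, max(encoder_seq_lengths), hidden_size)  # padded captions
    hidden_states = torch.randn(batch_size, image_seq_len, hidden_size)                     # image tokens

    attention_mask = torch.zeros(batch_size, max_seq_len, dtype=torch.bool)
    joint_hidden_states = torch.zeros(batch_size, max_seq_len, hidden_size)
    for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
        attention_mask[i, :seq_len] = True  # only real caption + image tokens attend
        joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
        joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]

    print(attention_mask.sum(dim=1).tolist())  # [7, 9]
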
@@ -520,11 +514,12 @@ def forward(
         hidden_states = self.norm_out(hidden_states, temb)

         # 5. Unpatchify
+        p = self.config.patch_size
         output = []
-        for i, (effective_encoder_seq_len, seq_len) in enumerate(zip(effective_encoder_seq_lengths, seq_lengths)):
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
             output.append(
-                hidden_states[i][effective_encoder_seq_len:seq_len]
-                .view(post_patch_height, post_patch_width, p, p, self.out_channels)
+                hidden_states[i][encoder_seq_len:seq_len]
+                .view(height // p, width // p, p, p, self.out_channels)
                 .permute(4, 0, 2, 1, 3)
                 .flatten(3, 4)
                 .flatten(1, 2)
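
Note (illustration only, not part of the diff): a standalone shape check of the unpatchify view/permute chain above, assuming toy values (patch size 2, a 4x4 latent, 3 output channels).

    import torch

    p, height, width, out_channels = 2, 4, 4, 3  # toy values
    num_patches = (height // p) * (width // p)

    # stand-in for hidden_states[i][encoder_seq_len:seq_len]
    patch_tokens = torch.randn(num_patches, p * p * out_channels)
    image = (
        patch_tokens.view(height // p, width // p, p, p, out_channels)
        .permute(4, 0, 2, 1, 3)  # (C, H/p, p, W/p, p)
        .flatten(3, 4)           # (C, H/p, p, W)
        .flatten(1, 2)           # (C, H, W)
    )
    print(image.shape)  # torch.Size([3, 4, 4])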