@@ -636,6 +636,11 @@ def __call__(
636636 if self .attention_kwargs is None :
637637 self ._attention_kwargs = {}
638638
639+ txt_seq_lens = prompt_embeds_mask .sum (dim = 1 ).tolist () if prompt_embeds_mask is not None else None
640+ negative_txt_seq_lens = (
641+ negative_prompt_embeds_mask .sum (dim = 1 ).tolist () if negative_prompt_embeds_mask is not None else None
642+ )
643+
639644 # 6. Denoising loop
640645 self .scheduler .set_begin_index (0 )
641646 with self .progress_bar (total = num_inference_steps ) as progress_bar :
@@ -654,7 +659,7 @@ def __call__(
654659 encoder_hidden_states_mask = prompt_embeds_mask ,
655660 encoder_hidden_states = prompt_embeds ,
656661 img_shapes = img_shapes ,
657- txt_seq_lens = prompt_embeds_mask . sum ( dim = 1 ). tolist () ,
662+ txt_seq_lens = txt_seq_lens ,
658663 attention_kwargs = self .attention_kwargs ,
659664 return_dict = False ,
660665 )[0 ]
@@ -668,7 +673,7 @@ def __call__(
668673 encoder_hidden_states_mask = negative_prompt_embeds_mask ,
669674 encoder_hidden_states = negative_prompt_embeds ,
670675 img_shapes = img_shapes ,
671- txt_seq_lens = negative_prompt_embeds_mask . sum ( dim = 1 ). tolist () ,
676+ txt_seq_lens = negative_txt_seq_lens ,
672677 attention_kwargs = self .attention_kwargs ,
673678 return_dict = False ,
674679 )[0 ]
0 commit comments