@@ -636,6 +636,11 @@ def __call__(
636636        if  self .attention_kwargs  is  None :
637637            self ._attention_kwargs  =  {}
638638
639+         txt_seq_lens  =  prompt_embeds_mask .sum (dim = 1 ).tolist () if  prompt_embeds_mask  is  not   None  else  None 
640+         negative_txt_seq_lens  =  (
641+             negative_prompt_embeds_mask .sum (dim = 1 ).tolist () if  negative_prompt_embeds_mask  is  not   None  else  None 
642+         )
643+ 
639644        # 6. Denoising loop 
640645        self .scheduler .set_begin_index (0 )
641646        with  self .progress_bar (total = num_inference_steps ) as  progress_bar :
@@ -654,7 +659,7 @@ def __call__(
654659                        encoder_hidden_states_mask = prompt_embeds_mask ,
655660                        encoder_hidden_states = prompt_embeds ,
656661                        img_shapes = img_shapes ,
657-                         txt_seq_lens = prompt_embeds_mask . sum ( dim = 1 ). tolist () ,
662+                         txt_seq_lens = txt_seq_lens ,
658663                        attention_kwargs = self .attention_kwargs ,
659664                        return_dict = False ,
660665                    )[0 ]
@@ -668,7 +673,7 @@ def __call__(
668673                            encoder_hidden_states_mask = negative_prompt_embeds_mask ,
669674                            encoder_hidden_states = negative_prompt_embeds ,
670675                            img_shapes = img_shapes ,
671-                             txt_seq_lens = negative_prompt_embeds_mask . sum ( dim = 1 ). tolist () ,
676+                             txt_seq_lens = negative_txt_seq_lens ,
672677                            attention_kwargs = self .attention_kwargs ,
673678                            return_dict = False ,
674679                        )[0 ]
0 commit comments