@@ -635,6 +635,11 @@ def __call__(
635635        if  self .attention_kwargs  is  None :
636636            self ._attention_kwargs  =  {}
637637
638+         txt_seq_lens  =  prompt_embeds_mask .sum (dim = 1 ).tolist () if  prompt_embeds_mask  is  not   None  else  None 
639+         negative_txt_seq_lens  =  (
640+             negative_prompt_embeds_mask .sum (dim = 1 ).tolist () if  negative_prompt_embeds_mask  is  not   None  else  None 
641+         )
642+ 
638643        # 6. Denoising loop 
639644        self .scheduler .set_begin_index (0 )
640645        with  self .progress_bar (total = num_inference_steps ) as  progress_bar :
@@ -653,7 +658,7 @@ def __call__(
653658                        encoder_hidden_states_mask = prompt_embeds_mask ,
654659                        encoder_hidden_states = prompt_embeds ,
655660                        img_shapes = img_shapes ,
656-                         txt_seq_lens = prompt_embeds_mask . sum ( dim = 1 ). tolist () ,
661+                         txt_seq_lens = txt_seq_lens ,
657662                        attention_kwargs = self .attention_kwargs ,
658663                        return_dict = False ,
659664                    )[0 ]
@@ -667,7 +672,7 @@ def __call__(
667672                            encoder_hidden_states_mask = negative_prompt_embeds_mask ,
668673                            encoder_hidden_states = negative_prompt_embeds ,
669674                            img_shapes = img_shapes ,
670-                             txt_seq_lens = negative_prompt_embeds_mask . sum ( dim = 1 ). tolist () ,
675+                             txt_seq_lens = negative_txt_seq_lens ,
671676                            attention_kwargs = self .attention_kwargs ,
672677                            return_dict = False ,
673678                        )[0 ]
0 commit comments