@@ -342,33 +342,24 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
342342        )
343343        (
344344            data .prompt_embeds ,
345+             data .negative_prompt_embeds ,
345346            data .pooled_prompt_embeds ,
347+             data .negative_pooled_prompt_embeds ,
346348        ) =  pipeline .encode_prompt (
347349            data .prompt ,
348350            data .prompt_2 ,
349351            data .device ,
352+             data .do_classifier_free_guidance ,
353+             data .negative_prompt ,
354+             data .negative_prompt_2 ,
350355            prompt_embeds = None ,
356+             negative_prompt_embeds = None ,
351357            pooled_prompt_embeds = None ,
358+             negative_pooled_prompt_embeds = None ,
352359            lora_scale = data .text_encoder_lora_scale ,
353360            clip_skip = data .clip_skip ,
361+             force_zeros_for_empty_prompt = self .configs .get ('force_zeros_for_empty_prompt' , False ),
354362        )
355-         zero_out_negative_prompt  =  data .negative_prompt  is  None  and  self .configs .get ('force_zeros_for_empty_prompt' , False )
356-         if  data .do_classifier_free_guidance  and  zero_out_negative_prompt :
357-             data .negative_prompt_embeds  =  torch .zeros_like (data .prompt_embeds )
358-             data .negative_pooled_prompt_embeds  =  torch .zeros_like (data .pooled_prompt_embeds )
359-         elif  data .do_classifier_free_guidance  and  not  zero_out_negative_prompt :
360-             (
361-                 data .negative_prompt_embeds ,
362-                 data .negative_pooled_prompt_embeds ,
363-             ) =  pipeline .encode_prompt (
364-                 data .negative_prompt ,
365-                 data .negative_prompt_2 ,
366-                 data .device ,
367-                 prompt_embeds = None ,
368-                 pooled_prompt_embeds = None ,
369-                 lora_scale = data .text_encoder_lora_scale ,
370-                 clip_skip = data .clip_skip ,
371-             )
372363        # Add outputs 
373364        self .add_block_state (state , data )
374365        return  pipeline , state 
@@ -3262,6 +3253,53 @@ def prepare_control_image(
32623253        return  image 
32633254
32643255    def  encode_prompt (
3256+         self ,
3257+         prompt : str ,
3258+         prompt_2 : Optional [str ] =  None ,
3259+         device : Optional [torch .device ] =  None ,
3260+         do_classifier_free_guidance : bool  =  True ,
3261+         negative_prompt : Optional [str ] =  None ,
3262+         negative_prompt_2 : Optional [str ] =  None ,
3263+         prompt_embeds : Optional [torch .Tensor ] =  None ,
3264+         negative_prompt_embeds : Optional [torch .Tensor ] =  None ,
3265+         pooled_prompt_embeds : Optional [torch .Tensor ] =  None ,
3266+         negative_pooled_prompt_embeds : Optional [torch .Tensor ] =  None ,
3267+         lora_scale : Optional [float ] =  None ,
3268+         clip_skip : Optional [int ] =  None ,
3269+         force_zeros_for_empty_prompt : bool  =  False ,
3270+     ):
3271+         (
3272+             prompt_embeds ,
3273+             pooled_prompt_embeds ,
3274+         ) =  self .encode_single_prompt (
3275+             prompt ,
3276+             prompt_2 ,
3277+             device ,
3278+             prompt_embeds = prompt_embeds ,
3279+             pooled_prompt_embeds = pooled_prompt_embeds ,
3280+             lora_scale = lora_scale ,
3281+             clip_skip = clip_skip ,
3282+         )
3283+         zero_out_negative_prompt  =  negative_prompt  is  None  and  force_zeros_for_empty_prompt 
3284+         if  do_classifier_free_guidance  and  zero_out_negative_prompt :
3285+             negative_prompt_embeds  =  torch .zeros_like (prompt_embeds )
3286+             negative_pooled_prompt_embeds  =  torch .zeros_like (pooled_prompt_embeds )
3287+         elif  do_classifier_free_guidance  and  not  zero_out_negative_prompt :
3288+             (
3289+                 negative_prompt_embeds ,
3290+                 negative_pooled_prompt_embeds ,
3291+             ) =  self .encode_single_prompt (
3292+                 negative_prompt ,
3293+                 negative_prompt_2 ,
3294+                 device ,
3295+                 prompt_embeds = negative_prompt_embeds ,
3296+                 pooled_prompt_embeds = negative_pooled_prompt_embeds ,
3297+                 lora_scale = lora_scale ,
3298+                 clip_skip = clip_skip ,
3299+             )
3300+         return  prompt_embeds , negative_prompt_embeds , pooled_prompt_embeds , negative_pooled_prompt_embeds 
3301+ 
3302+     def  encode_single_prompt (
32653303        self ,
32663304        prompt : str ,
32673305        prompt_2 : Optional [str ] =  None ,
0 commit comments