@@ -790,20 +790,21 @@ def __call__(
790790 max_sequence_length = max_sequence_length ,
791791 lora_scale = lora_scale ,
792792 )
793- (
794- negative_prompt_embeds ,
795- negative_pooled_prompt_embeds ,
796- _ ,
797- ) = self .encode_prompt (
798- prompt = negative_prompt ,
799- prompt_2 = negative_prompt_2 ,
800- prompt_embeds = negative_prompt_embeds ,
801- pooled_prompt_embeds = negative_pooled_prompt_embeds ,
802- device = device ,
803- num_images_per_prompt = num_images_per_prompt ,
804- max_sequence_length = max_sequence_length ,
805- lora_scale = lora_scale ,
806- )
793+ if do_true_cfg :
794+ (
795+ negative_prompt_embeds ,
796+ negative_pooled_prompt_embeds ,
797+ _ ,
798+ ) = self .encode_prompt (
799+ prompt = negative_prompt ,
800+ prompt_2 = negative_prompt_2 ,
801+ prompt_embeds = negative_prompt_embeds ,
802+ pooled_prompt_embeds = negative_pooled_prompt_embeds ,
803+ device = device ,
804+ num_images_per_prompt = num_images_per_prompt ,
805+ max_sequence_length = max_sequence_length ,
806+ lora_scale = lora_scale ,
807+ )
807808
808809 # 4. Prepare latent variables
809810 num_channels_latents = self .transformer .config .in_channels // 4
0 commit comments