@@ -221,15 +221,24 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
             ip_adapter_image=data.ip_adapter_image,
             ip_adapter_image_embeds=None,
             device=data.device,
-            num_images_per_prompt=1,
-            do_classifier_free_guidance=data.do_classifier_free_guidance,
         )
+
         if data.do_classifier_free_guidance:
-            data.negative_ip_adapter_embeds = []
-            for i, image_embeds in enumerate(data.ip_adapter_embeds):
-                negative_image_embeds, image_embeds = image_embeds.chunk(2)
-                data.negative_ip_adapter_embeds.append(negative_image_embeds)
-                data.ip_adapter_embeds[i] = image_embeds
+            output_hidden_states = [not isinstance(image_proj_layer, ImageProjection) for image_proj_layer in pipeline.unet.encoder_hid_proj.image_projection_layers]
+            negative_ip_adapter_embeds = []
+            for (idx, output_hidden_state), ip_adapter_embeds in zip(enumerate(output_hidden_states), data.ip_adapter_embeds):
+                if not output_hidden_state:
+                    negative_ip_adapter_embed = torch.zeros_like(ip_adapter_embeds)
+                else:
+                    ip_adapter_image = data.ip_adapter_image[idx] if isinstance(data.ip_adapter_image, list) else data.ip_adapter_image
+                    ip_adapter_image = pipeline.feature_extractor(ip_adapter_image, return_tensors="pt").pixel_values
+                    negative_ip_adapter_embed = pipeline.prepare_ip_adapter_image_embeds(
+                        ip_adapter_image=ip_adapter_image,
+                        ip_adapter_image_embeds=None,
+                        device=data.device,
+                    )
+                negative_ip_adapter_embeds.append(negative_ip_adapter_embed)
+            data.negative_ip_adapter_embeds = negative_ip_adapter_embeds
 
         self.add_block_state(state, data)
         return pipeline, state
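Guidance handling now lives outside the embedding helpers, so downstream code that wants classifier-free guidance concatenates the unconditional and conditional IP-Adapter embeds itself. A minimal consumer-side sketch of that step (the helper name and inputs are hypothetical, not part of this diff):

```python
import torch

# Hypothetical helper: given the per-adapter lists produced above, build the
# [uncond, cond] batches that classifier-free guidance expects downstream.
def assemble_cfg_ip_adapter_embeds(ip_adapter_embeds, negative_ip_adapter_embeds):
    return [
        torch.cat([negative, positive], dim=0)
        for negative, positive in zip(negative_ip_adapter_embeds, ip_adapter_embeds)
    ]
```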
@@ -333,24 +342,33 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
         )
         (
             data.prompt_embeds,
-            data.negative_prompt_embeds,
             data.pooled_prompt_embeds,
-            data.negative_pooled_prompt_embeds,
         ) = pipeline.encode_prompt(
             data.prompt,
             data.prompt_2,
             data.device,
-            1,
-            data.do_classifier_free_guidance,
-            data.negative_prompt,
-            data.negative_prompt_2,
             prompt_embeds=None,
-            negative_prompt_embeds=None,
             pooled_prompt_embeds=None,
-            negative_pooled_prompt_embeds=None,
             lora_scale=data.text_encoder_lora_scale,
             clip_skip=data.clip_skip,
         )
+        zero_out_negative_prompt = data.negative_prompt is None and self.configs.get('force_zeros_for_empty_prompt', False)
+        if data.do_classifier_free_guidance and zero_out_negative_prompt:
+            data.negative_prompt_embeds = torch.zeros_like(data.prompt_embeds)
+            data.negative_pooled_prompt_embeds = torch.zeros_like(data.pooled_prompt_embeds)
+        elif data.do_classifier_free_guidance and not zero_out_negative_prompt:
+            (
+                data.negative_prompt_embeds,
+                data.negative_pooled_prompt_embeds,
+            ) = pipeline.encode_prompt(
+                data.negative_prompt,
+                data.negative_prompt_2,
+                data.device,
+                prompt_embeds=None,
+                pooled_prompt_embeds=None,
+                lora_scale=data.text_encoder_lora_scale,
+                clip_skip=data.clip_skip,
+            )
         # Add outputs
         self.add_block_state(state, data)
         return pipeline, state
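The block now encodes the negative prompt with a second, symmetric call to `encode_prompt` instead of threading guidance flags through a single call; when no negative prompt is given and the SDXL `force_zeros_for_empty_prompt` config is set, the unconditional embeddings are simply zeroed. A standalone sketch of that branch, with random tensors standing in for real embeddings (shapes are illustrative only):

```python
import torch

# Stand-ins for the block's state; names mirror the diff, values are made up.
prompt_embeds = torch.randn(1, 77, 2048)
pooled_prompt_embeds = torch.randn(1, 1280)
negative_prompt = None
do_classifier_free_guidance = True
force_zeros_for_empty_prompt = True  # mirrors the SDXL config flag

zero_out_negative_prompt = negative_prompt is None and force_zeros_for_empty_prompt
if do_classifier_free_guidance and zero_out_negative_prompt:
    # Empty negative prompt: SDXL convention is zeroed unconditional embeddings.
    negative_prompt_embeds = torch.zeros_like(prompt_embeds)
    negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
elif do_classifier_free_guidance:
    # Otherwise encode_prompt would simply be called a second time with the
    # negative prompt (elided here; requires a real pipeline).
    pass
```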
@@ -3197,8 +3215,7 @@ def _get_add_time_ids_img2img(
 
         return add_time_ids, add_neg_time_ids
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+    def encode_image(self, image, device, output_hidden_states=None):
         dtype = next(self.image_encoder.parameters()).dtype
 
         if not isinstance(image, torch.Tensor):
@@ -3207,20 +3224,10 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
         image = image.to(device=device, dtype=dtype)
         if output_hidden_states:
             image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
-            uncond_image_enc_hidden_states = self.image_encoder(
-                torch.zeros_like(image), output_hidden_states=True
-            ).hidden_states[-2]
-            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
-                num_images_per_prompt, dim=0
-            )
-            return image_enc_hidden_states, uncond_image_enc_hidden_states
+            return image_enc_hidden_states
         else:
             image_embeds = self.image_encoder(image).image_embeds
-            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
-            uncond_image_embeds = torch.zeros_like(image_embeds)
-
-            return image_embeds, uncond_image_embeds
+            return image_embeds
 
     # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
     # 1. return image without applying any guidance
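`encode_image` no longer duplicates embeddings per prompt or computes the unconditional branch; both become the caller's job. A self-contained toy sketch of the resulting call pattern (the encoder class below is a stand-in for the CLIP image encoder, not library code):

```python
import torch
import torch.nn as nn

# Toy stand-in for the CLIP image encoder, just to show shapes.
class ToyImageEncoder(nn.Module):
    def forward(self, image):
        return image.mean(dim=(2, 3))  # pretend these are pooled image embeds

encoder = ToyImageEncoder()
image = torch.rand(1, 3, 224, 224)

image_embeds = encoder(image)                         # conditional branch
uncond_image_embeds = torch.zeros_like(image_embeds)  # caller-built unconditional branch
```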
@@ -3254,20 +3261,13 @@ def prepare_control_image(
 
         return image
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: str,
         prompt_2: Optional[str] = None,
         device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        do_classifier_free_guidance: bool = True,
-        negative_prompt: Optional[str] = None,
-        negative_prompt_2: Optional[str] = None,
         prompt_embeds: Optional[torch.Tensor] = None,
-        negative_prompt_embeds: Optional[torch.Tensor] = None,
         pooled_prompt_embeds: Optional[torch.Tensor] = None,
-        negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
         lora_scale: Optional[float] = None,
         clip_skip: Optional[int] = None,
     ):
@@ -3282,31 +3282,12 @@ def encode_prompt(
                 used in both text-encoders
             device: (`torch.device`):
                 torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            do_classifier_free_guidance (`bool`):
-                whether to use classifier free guidance or not
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-                less than `1`).
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
-                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
             prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
             pooled_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
-                input argument.
             lora_scale (`float`, *optional*):
                 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
             clip_skip (`int`, *optional*):
@@ -3391,92 +3372,11 @@ def encode_prompt(
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
 
-        # get unconditional embeddings for classifier free guidance
-        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
-        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
-            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
-            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
-        elif do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt_2 = negative_prompt_2 or negative_prompt
-
-            # normalize str to list
-            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_2 = (
-                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
-            )
-
-            uncond_tokens: List[str]
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = [negative_prompt, negative_prompt_2]
-
-            negative_prompt_embeds_list = []
-            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
-                if isinstance(self, TextualInversionLoaderMixin):
-                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
-
-                max_length = prompt_embeds.shape[1]
-                uncond_input = tokenizer(
-                    negative_prompt,
-                    padding="max_length",
-                    max_length=max_length,
-                    truncation=True,
-                    return_tensors="pt",
-                )
-
-                negative_prompt_embeds = text_encoder(
-                    uncond_input.input_ids.to(device),
-                    output_hidden_states=True,
-                )
-                # We are only ALWAYS interested in the pooled output of the final text encoder
-                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
-                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
-
-                negative_prompt_embeds_list.append(negative_prompt_embeds)
-
-            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
-
         if self.text_encoder_2 is not None:
             prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
         else:
             prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
 
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            if self.text_encoder_2 is not None:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
-            else:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-            bs_embed * num_images_per_prompt, -1
-        )
-        if do_classifier_free_guidance:
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-                bs_embed * num_images_per_prompt, -1
-            )
-
         if self.text_encoder is not None:
             if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
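With `num_images_per_prompt` removed from `encode_prompt`, the per-prompt duplication the deleted block performed becomes the caller's responsibility. A short sketch of the equivalent external step, reusing the mps-friendly repeat-then-view idiom from the removed code (tensor shapes are illustrative, not taken from the diff):

```python
import torch

num_images_per_prompt = 2
prompt_embeds = torch.randn(1, 77, 2048)    # (batch, seq_len, dim)
pooled_prompt_embeds = torch.randn(1, 1280)

# Duplicate embeddings for each generation per prompt, as the removed code did.
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
    bs_embed * num_images_per_prompt, -1
)
```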
@@ -3487,16 +3387,13 @@ def encode_prompt(
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
-        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+        return prompt_embeds, pooled_prompt_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+        self, ip_adapter_image, ip_adapter_image_embeds, device
     ):
-        image_embeds = []
-        if do_classifier_free_guidance:
-            negative_image_embeds = []
         if ip_adapter_image_embeds is None:
+            ip_adapter_image_embeds = []
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]
 
@@ -3509,29 +3406,13 @@ def prepare_ip_adapter_image_embeds(
                 ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
             ):
                 output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
-                single_image_embeds, single_negative_image_embeds = self.encode_image(
-                    single_ip_adapter_image, device, 1, output_hidden_state
+                single_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, output_hidden_state
                 )
+                ip_adapter_image_embeds.append(single_image_embeds[None, :])
 
-                image_embeds.append(single_image_embeds[None, :])
-                if do_classifier_free_guidance:
-                    negative_image_embeds.append(single_negative_image_embeds[None, :])
-        else:
-            for single_image_embeds in ip_adapter_image_embeds:
-                if do_classifier_free_guidance:
-                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-                    negative_image_embeds.append(single_negative_image_embeds)
-                image_embeds.append(single_image_embeds)
-
-        ip_adapter_image_embeds = []
-        for i, single_image_embeds in enumerate(image_embeds):
-            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-            if do_classifier_free_guidance:
-                single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
-                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
-
+        for single_image_embeds in ip_adapter_image_embeds:
             single_image_embeds = single_image_embeds.to(device=device)
-            ip_adapter_image_embeds.append(single_image_embeds)
 
         return ip_adapter_image_embeds
 
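After the refactor, `prepare_ip_adapter_image_embeds` returns one conditional tensor per IP-Adapter with a leading batch axis added via `[None, :]`, and no negative embeds are concatenated in. A self-contained sketch of just that return shape (the function below is an illustrative mimic, not the pipeline method):

```python
import torch

# Illustrative mimic of the slimmed-down helper's output handling: one tensor
# per adapter, conditional only, each given a leading batch dimension.
def prepare_embeds_sketch(per_adapter_embeds, device):
    out = []
    for single_image_embeds in per_adapter_embeds:
        out.append(single_image_embeds[None, :].to(device=device))
    return out

embeds = prepare_embeds_sketch([torch.randn(4, 768), torch.randn(1024)], "cpu")
print([e.shape for e in embeds])  # [torch.Size([1, 4, 768]), torch.Size([1, 1024])]
```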