@@ -401,9 +401,11 @@ def encode_image(self, image, device, num_images_per_prompt):
401401 return image_embeds
402402
403403 def prepare_ip_adapter_image_embeds (
404- self , ip_adapter_image , ip_adapter_image_embeds , device , num_images_per_prompt
404+ self , ip_adapter_image , ip_adapter_image_embeds , device , num_images_per_prompt , height , width , dtype
405405 ):
406406 image_embeds = []
407+ negative_embeds = []
408+ negative_image = np .zeros ((width , height , 3 ), dtype = np .uint8 )
407409 if ip_adapter_image_embeds is None :
408410 if not isinstance (ip_adapter_image , list ):
409411 ip_adapter_image = [ip_adapter_image ]
@@ -417,19 +419,27 @@ def prepare_ip_adapter_image_embeds(
417419 ip_adapter_image , self .transformer .encoder_hid_proj .image_projection_layers
418420 ):
419421 single_image_embeds = self .encode_image (single_ip_adapter_image , device , 1 )
422+ negative_image_embeds = self .encode_image (negative_image , device , 1 )
420423
421424 image_embeds .append (single_image_embeds [None , :])
422425 image_embeds = self .transformer .encoder_hid_proj (image_embeds )
426+ negative_embeds .append (negative_image_embeds [None , :])
427+ negative_embeds = self .transformer .encoder_hid_proj (negative_embeds )
423428 else :
424429 for single_image_embeds in ip_adapter_image_embeds :
425430 image_embeds = self .transformer .encoder_hid_proj (single_image_embeds )
426431 image_embeds .append (single_image_embeds )
432+ negative_image_embeds = self .encode_image (negative_image , device , 1 )
433+ negative_embeds .append (negative_image_embeds [None , :])
434+ negative_embeds = self .transformer .encoder_hid_proj (negative_embeds )
427435
428436 ip_adapter_image_embeds = []
429- for i , single_image_embeds in enumerate (image_embeds ):
437+ for i , ( single_image_embeds , negative_image_embed ) in enumerate (zip ( image_embeds , negative_embeds ) ):
430438 single_image_embeds = torch .cat ([single_image_embeds ] * num_images_per_prompt , dim = 0 )
431- single_image_embeds = single_image_embeds .to (device = device )
432- ip_adapter_image_embeds .append (single_image_embeds )
439+ single_image_embeds = single_image_embeds .to (device = device , dtype = dtype )
440+ negative_image_embed = torch .cat ([negative_image_embed ] * num_images_per_prompt , dim = 0 )
441+ negative_image_embed = negative_image_embed .to (device = device , dtype = dtype )
442+ ip_adapter_image_embeds .append ((single_image_embeds , negative_image_embed ))
433443
434444 return ip_adapter_image_embeds
435445
@@ -794,6 +804,9 @@ def __call__(
794804 ip_adapter_image_embeds ,
795805 device ,
796806 batch_size * num_images_per_prompt ,
807+ height ,
808+ width ,
809+ latents .dtype ,
797810 )
798811 if self .joint_attention_kwargs is None :
799812 self ._joint_attention_kwargs = {}
0 commit comments