Commit ab0d904

ip_adapter image embeds now considers num_images_per_prompt

1 parent de8909a

1 file changed (+15, -35 lines)

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 15 additions & 35 deletions
@@ -695,42 +695,22 @@ def encode_image(self, image):
     def prepare_ip_adapter_image_embeds(
         self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
     ):
-        # image_embeds = []
-
-        # if do_classifier_free_guidance:
-        #     negative_image_embeds = []
-
-        # if ip_adapter_image_embeds is None:
-        #     single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image)
-        #     image_embeds.append(single_image_embeds[None, :])
-
-        #     if do_classifier_free_guidance:
-        #         negative_image_embeds.append(single_negative_image_embeds[None, :])
-        # else:
-        #     for single_image_embeds in ip_adapter_image_embeds:
-        #         if do_classifier_free_guidance:
-        #             single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
-        #             negative_image_embeds.append(single_negative_image_embeds)
-        #         image_embeds.append(single_image_embeds)
-
-        # ip_adapter_image_embeds = []
-        # for i, single_image_embeds in enumerate(image_embeds):
-        #     single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
-
-        #     if do_classifier_free_guidance:
-        #         single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
-        #         single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
-
-        #     single_image_embeds = single_image_embeds.to(device=device)
-        #     ip_adapter_image_embeds.append(single_image_embeds)
-
-
-        # Single image only :/
-        clip_image_tensor = self.feature_extractor(images=ip_adapter_image, return_tensors="pt").pixel_values
-        clip_image_tensor = clip_image_tensor.to(device, dtype=self.dtype)
-        clip_image_embeds = self.image_encoder(clip_image_tensor, output_hidden_states=True).hidden_states[-2]
+        if ip_adapter_image_embeds is None:
+            single_image_embeds, single_negative_image_embeds = self.encode_image(ip_adapter_image)
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                else:
+                    single_image_embeds = ip_adapter_image_embeds
+
+        single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)

-        return torch.cat([torch.zeros_like(clip_image_embeds), clip_image_embeds], dim=0)
+        if do_classifier_free_guidance:
+            single_negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+        return single_image_embeds.to(device=device)

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
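The behavior this commit adds is the tiling step: each image embedding is repeated num_images_per_prompt times along the batch dimension, and under classifier-free guidance the repeated negative embeddings are concatenated in front of the positive ones. A minimal standalone sketch of that logic follows; the helper name tile_image_embeds and the tensor shapes are illustrative assumptions, not part of the pipeline.

import torch

def tile_image_embeds(image_embeds, negative_image_embeds, num_images_per_prompt, do_classifier_free_guidance):
    # Repeat the embedding once per requested image, mirroring the
    # commit's torch.cat([single_image_embeds] * num_images_per_prompt, dim=0).
    image_embeds = torch.cat([image_embeds] * num_images_per_prompt, dim=0)
    if do_classifier_free_guidance:
        negative_image_embeds = torch.cat([negative_image_embeds] * num_images_per_prompt, dim=0)
        # Negative (unconditional) embeddings go first, so one batched
        # forward pass covers both halves of classifier-free guidance.
        image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
    return image_embeds

# Illustrative shapes (assumed): batch of 1, 257 CLIP tokens, hidden size 1280.
pos = torch.randn(1, 257, 1280)
neg = torch.zeros_like(pos)
out = tile_image_embeds(pos, neg, num_images_per_prompt=2, do_classifier_free_guidance=True)
print(out.shape)  # torch.Size([4, 257, 1280])

With num_images_per_prompt=2 and guidance enabled, the batch dimension grows from 1 to 4: two negative rows followed by two positive rows, the same ordering the new code returns.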
