Skip to content

Commit 9e9b0f8

Browse files
committed
Changes for ipa image embeds
1 parent 339c0b4 commit 9e9b0f8

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

examples/community/pipeline_flux_semantic_guidance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -537,9 +537,9 @@ def prepare_ip_adapter_image_embeds(
537537
if not isinstance(ip_adapter_image, list):
538538
ip_adapter_image = [ip_adapter_image]
539539

540-
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
540+
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
541541
raise ValueError(
542-
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
542+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
543543
)
544544

545545
for single_ip_adapter_image, image_proj_layer in zip(

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -410,18 +410,23 @@ def prepare_ip_adapter_image_embeds(
410410
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
411411
)
412412

413-
for single_ip_adapter_image, image_proj_layer in zip(
414-
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
415-
):
413+
for single_ip_adapter_image in ip_adapter_image:
416414
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
417-
418415
image_embeds.append(single_image_embeds[None, :])
419416
else:
417+
if not isinstance(ip_adapter_image_embeds, list):
418+
ip_adapter_image_embeds = [ip_adapter_image_embeds]
419+
420+
if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
421+
raise ValueError(
422+
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
423+
)
424+
420425
for single_image_embeds in ip_adapter_image_embeds:
421426
image_embeds.append(single_image_embeds)
422427

423428
ip_adapter_image_embeds = []
424-
for i, single_image_embeds in enumerate(image_embeds):
429+
for single_image_embeds in image_embeds:
425430
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
426431
single_image_embeds = single_image_embeds.to(device=device)
427432
ip_adapter_image_embeds.append(single_image_embeds)
@@ -868,19 +873,19 @@ def __call__(
868873
else:
869874
guidance = None
870875

871-
# TODO: Clarify this section
872876
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
873877
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
874878
):
875-
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
876-
if isinstance(ip_adapter_image, list):
877-
negative_ip_adapter_image = [negative_ip_adapter_image] * len(ip_adapter_image)
879+
zeros_image = np.zeros((width, height, 3), dtype=np.uint8)
880+
negative_ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :]
881+
negative_ip_adapter_image_embeds = [negative_ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters
882+
878883
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
879884
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
880885
):
881-
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
882-
if isinstance(negative_ip_adapter_image, list):
883-
ip_adapter_image = [ip_adapter_image] * len(negative_ip_adapter_image)
886+
zeros_image = np.zeros((width, height, 3), dtype=np.uint8)
887+
ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :]
888+
ip_adapter_image_embeds = [ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters
884889

885890
if self.joint_attention_kwargs is None:
886891
self._joint_attention_kwargs = {}

0 commit comments

Comments
 (0)