4 changes: 2 additions & 2 deletions examples/community/pipeline_flux_semantic_guidance.py
@@ -537,9 +537,9 @@ def prepare_ip_adapter_image_embeds(
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]

-            if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
+            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
                 raise ValueError(
-                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
                 )

             for single_ip_adapter_image, image_proj_layer in zip(
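For context, `encoder_hid_proj` on the Flux transformer is a multi-IP-adapter projection container, and in recent diffusers versions `num_ip_adapters` is just a property over `image_projection_layers`; the community pipeline switches to `len(...)` so it also works where that property is unavailable. A minimal sketch of the assumed structure (the real class is `diffusers.models.embeddings.MultiIPAdapterImageProjection`; details may differ across versions):

import torch.nn as nn

class MultiIPAdapterImageProjection(nn.Module):
    """Sketch of the projection container assumed by the length checks above."""

    def __init__(self, image_projection_layers):
        super().__init__()
        # One projection layer is registered per loaded IP adapter.
        self.image_projection_layers = nn.ModuleList(image_projection_layers)

    @property
    def num_ip_adapters(self) -> int:
        # Equivalent to len(self.image_projection_layers), which is what
        # the community pipeline now calls directly.
        return len(self.image_projection_layers)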
29 changes: 17 additions & 12 deletions src/diffusers/pipelines/flux/pipeline_flux.py
@@ -410,18 +410,23 @@ def prepare_ip_adapter_image_embeds(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)

for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
):
for single_ip_adapter_image in ip_adapter_image:
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)

image_embeds.append(single_image_embeds[None, :])
else:
if not isinstance(ip_adapter_image_embeds, list):
ip_adapter_image_embeds = [ip_adapter_image_embeds]

if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
raise ValueError(
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)

for single_image_embeds in ip_adapter_image_embeds:
image_embeds.append(single_image_embeds)

ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
for single_image_embeds in image_embeds:
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
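As a side note, the final loop above only repeats each embedding along the batch dimension for `num_images_per_prompt`. A tiny standalone sketch of that batching step, with a hypothetical embedding shape:

import torch

# Hypothetical shape: one IP-adapter image embedding with batch dim 1.
single_image_embeds = torch.randn(1, 768)
num_images_per_prompt = 2

# Same step as in the hunk above: duplicate the embedding along dim 0.
repeated = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
assert repeated.shape == (2, 768)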
@@ -868,19 +873,19 @@ def __call__(
         else:
             guidance = None

         # TODO: Clarify this section
         if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
             negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
         ):
-            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
-            if isinstance(ip_adapter_image, list):
-                negative_ip_adapter_image = [negative_ip_adapter_image] * len(ip_adapter_image)
+            zeros_image = np.zeros((width, height, 3), dtype=np.uint8)
+            negative_ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :]
+            negative_ip_adapter_image_embeds = [negative_ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters

         elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
             negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
         ):
-            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
-            if isinstance(negative_ip_adapter_image, list):
-                ip_adapter_image = [ip_adapter_image] * len(negative_ip_adapter_image)
+            zeros_image = np.zeros((width, height, 3), dtype=np.uint8)
+            ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :]
+            ip_adapter_image_embeds = [ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters

         if self.joint_attention_kwargs is None:
             self._joint_attention_kwargs = {}
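From the caller's side, the effect of this hunk is that passing only a positive `ip_adapter_image` makes the pipeline synthesize the negative side by encoding an all-zeros image of the same size. A hedged usage sketch (checkpoint IDs, weight names, and the image URL are illustrative; verify them on the Hub):

import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",  # illustrative checkpoint
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)
pipe.enable_model_cpu_offload()

style_image = load_image("https://example.com/style.png")  # placeholder image URL

# Only the positive IP-adapter image is supplied; with this change the
# pipeline builds the negative embeds from a zero image of shape (width, height, 3).
image = pipe(
    prompt="a cat wearing a spacesuit",
    ip_adapter_image=style_image,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_ip_adapter.png")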