diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index effdef465281..0ce8628c0822 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -440,23 +440,28 @@ def prepare_ip_adapter_image_embeds( if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] - if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." ) - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers - ): + for single_ip_adapter_image in ip_adapter_image: single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) - image_embeds.append(single_image_embeds[None, :]) else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + for single_image_embeds in ip_adapter_image_embeds: image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): + for single_image_embeds in image_embeds: single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index 8e9991bc60e5..a56ed33c4e55 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -427,23 +427,28 @@ def prepare_ip_adapter_image_embeds( if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] - if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." ) - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers - ): + for single_ip_adapter_image in ip_adapter_image: single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) - image_embeds.append(single_image_embeds[None, :]) else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + for single_image_embeds in ip_adapter_image_embeds: image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): + for single_image_embeds in image_embeds: single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index eced1b3f09f2..43bba1c6e7c3 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -432,23 +432,28 @@ def prepare_ip_adapter_image_embeds( if not isinstance(ip_adapter_image, list): ip_adapter_image = [ip_adapter_image] - if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers): + if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters: raise ValueError( - f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters." + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." ) - for single_ip_adapter_image, image_proj_layer in zip( - ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers - ): + for single_ip_adapter_image in ip_adapter_image: single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1) - image_embeds.append(single_image_embeds[None, :]) else: + if not isinstance(ip_adapter_image_embeds, list): + ip_adapter_image_embeds = [ip_adapter_image_embeds] + + if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters: + raise ValueError( + f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters." + ) + for single_image_embeds in ip_adapter_image_embeds: image_embeds.append(single_image_embeds) ip_adapter_image_embeds = [] - for i, single_image_embeds in enumerate(image_embeds): + for single_image_embeds in image_embeds: single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) single_image_embeds = single_image_embeds.to(device=device) ip_adapter_image_embeds.append(single_image_embeds)