Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 58 additions & 9 deletions src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL
import torch
from transformers import (
CLIPTextModel,
Expand Down Expand Up @@ -389,10 +390,31 @@ def encode_prompt(

return prompt_embeds, pooled_prompt_embeds, text_ids

# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
Copy link
Collaborator

@yiyixuxu yiyixuxu Sep 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so. I think this method was made before we introduced image processor, which set a standard image input format we accept across all our pipelines and check if it is a valid format there

if not is_valid_image_imagelist(image):

def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
if image_is_pil:
image_batch_size = 1
else:
image_batch_size = len(image)

if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]

if image_batch_size != 1 and image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)

def check_inputs(
self,
prompt,
prompt_2,
image,
height,
width,
prompt_embeds=None,
Expand Down Expand Up @@ -429,6 +451,30 @@ def check_inputs(
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

if (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove almost all of the checks and only do following two things:

  1. make sure if it is a multi controlnet, image is a list with same length of number of controlnet
  2. add a check here
        if image_batch_size == 1:
            repeat_by = batch_size
        elif image_batch_size == batch_size:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        else:
             raise ValueError("...")

it should be sufficient no? would we miss anything here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the first question, I already added this checking in an earlier commit, see:

https://github.com/christopher-beckham/diffusers/blob/flux_controlnet_input_checking/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py#L458-L474

For (2), that if elif statement will break everything since the batch_size that is actually passed into prepare_image outside of it is actually batch_size*num_images_per_prompt, i.e:

control_image = self.prepare_image(
    ...
    batch_size=batch_size * num_images_per_prompt,
    num_images_per_prompt=num_images_per_prompt,
    ...
)

It's a little confusing to parse (esp since we also pass num_images_per_prompt into that method) so I changed it to the following:

control_image = self.prepare_image(
    ...
    batch_size=batch_size,
    num_images_per_prompt=num_images_per_prompt,
    ...
)

and made an adjustment to the code inside that method, so now we have:

if image_batch_size == 1:
    repeat_by = batch_size*num_images_per_prompt
elif image_batch_size == batch_size:
    # image batch size is the same as prompt batch size
    repeat_by = num_images_per_prompt
else:
    raise ValueError(...)

I wrote an informative ValueError as well in the event the else statement gets tripped.

I'll push these changes momentarily.

isinstance(self.controlnet, FluxControlNetModel)
):
self.check_image(image, prompt, prompt_embeds)
elif (
isinstance(self.controlnet, FluxMultiControlNetModel)
):
if not isinstance(image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")

# When `image` is a nested list:
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
elif any(isinstance(i, list) for i in image):
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
elif len(image) != len(self.controlnet.nets):
raise ValueError(
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
)

for image_ in image:
self.check_image(image_, prompt, prompt_embeds)
else:
assert False

if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
Expand Down Expand Up @@ -523,18 +569,20 @@ def prepare_image(
do_classifier_free_guidance=False,
guess_mode=False,
):
if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)


image = self.image_processor.preprocess(image, height=height, width=width)
image_batch_size = image.shape[0]

if image_batch_size == 1:
repeat_by = batch_size
else:
repeat_by = batch_size*num_images_per_prompt
elif image_batch_size == batch_size:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
else:
raise ValueError(
"`image_batch_size` must be either 1 or equal to the prompt " + \
f"batch size, which is {batch_size}."
)

image = image.repeat_interleave(repeat_by, dim=0)

Expand Down Expand Up @@ -678,6 +726,7 @@ def __call__(
self.check_inputs(
prompt,
prompt_2,
control_image,
height,
width,
prompt_embeds=prompt_embeds,
Expand Down Expand Up @@ -726,7 +775,7 @@ def __call__(
image=control_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
Expand Down Expand Up @@ -762,7 +811,7 @@ def __call__(
image=control_image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
batch_size=batch_size,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.vae.dtype,
Expand Down