Skip to content

Conversation

@christopher-beckham
Copy link
Contributor

@christopher-beckham christopher-beckham commented Sep 28, 2024

What does this PR do?

Issue

This addresses an issue discussed in a two PRs, see #9406 (comment) and #9507 (comment)

The FLUX controlnet pipeline is actually lacking any checks for the shape or number of control images passed (for np.ndarray or torch.Tensor and PIL objects, respectively).

I will give a simple example. If you were to run the following code:

pipe = FluxControlNetPipeline.from_pretrained(
    base_model, vae=vae, controlnet=multi_controlnet, torch_dtype=torch.bfloat16
).to("cuda")
# image_t is a torch tensor of shape (2,3,h,w)
self.pipe(
    prompt=["test"],
    control_image=image_t, 
    control_mode=0, 
    num_images_per_prompt=1,
    num_inference_steps=2
)

you'd get the following error:

Traceback (most recent call last):
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers-tests/controlnet_pipeline_cleaner_api/flux.py", line 67, in test_torch_batched_ctrl_wrong_1ipp
    self.pipe(
  File "/home/beckhamc/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 742, in __call__
    control_image = self._pack_latents(
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 458, in _pack_latents
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
RuntimeError: shape '[1, 16, 32, 2, 32, 2]' is invalid for input of size 131072

This is actually because the number of control images must match the number of prompts passed -- in this case we passed in a control image of batch size 2 but the number of prompts passed is 1. Because we don't catch for this, it results in a downstream error related to the packing of the latents.

It turns out SDXL's controlnet actually checks to make sure the number of control images are consistent with the number of prompts (I do recall one of the two are also allowed to be a singleton list, which is also fine). I essentially ported over the check_image method from StableDiffusionControlNetPipeline as well as modify check_inputs to actually check the control image as well. Now if you run the above code you will get the following error instead, which makes it much clearer what the issue is:

Traceback (most recent call last):
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers-tests/controlnet_pipeline_cleaner_api/flux.py", line 67, in test_torch_batched_ctrl_wrong_1ipp
    self.pipe(
  File "/home/beckhamc/envs/diffusers/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 742, in __call__
    self.check_inputs(
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 475, in check_inputs
    self.check_image(image, prompt, prompt_embeds)
  File "/network/scratch/b/beckhamc/github/diffusion-pr/diffusers/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py", line 427, in check_image
    raise ValueError(
ValueError: If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: 2, prompt batch size: 1

This fix should also work for MultiControlNet, which means you can do something like this:

multi_controlnet = FluxMultiControlNetModel([controlnet] * 2)
pipe = FluxControlNetPipeline.from_pretrained(
    base_model, vae=vae, controlnet=multi_controlnet, torch_dtype=torch.bfloat16
).to("cuda")
images = pipe(
    prompt=["1","2","3"],
    control_image=[images1, images2], 
    controlnet_conditioning_scale=[0.6, 0.6],
    control_mode=0,
    num_images_per_prompt=2
)

i.e. images and images2 are both torch.Tensor with a batch size of 3, and their corresponding ControlNet states (which will be effectively have double batch size due to num_images_per_prompt=2) will be summed together.

I have some tests you can copy and paste from here: https://github.com/christopher-beckham/diffusers-tests/blob/4b548f8/controlnet_pipeline_cleaner_api/flux.py

(you can run with python -m unittest flux.py)

Other concerns

There are some questions I have however. Why is it that we skip the image preprocessing if the image is torch.Tensor? i.e.

if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)

This also seems inconsistent with what is done in the SDXL ControlNet code:

image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)

It may also lead to unexpected behaviour because preprocess explicitly tries to use width and height to preprocess the image (if they are None, then a reasonable default is used instead, depending on what the precise model is). But this logic gets skipped entirely if a torch.Tensor is passed.

Thanks.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you read our philosophy doc (important for complex PRs)?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings
  • Did you write any new necessary tests? (Yes but in my own standalone repo which I linked to )

Who can review?

@yiyixuxu @wangqixun


return prompt_embeds, pooled_prompt_embeds, text_ids

# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
Copy link
Collaborator

@yiyixuxu yiyixuxu Sep 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so. I think this method was made before we introduced image processor, which set a standard image input format we accept across all our pipelines and check if it is a valid format there

if not is_valid_image_imagelist(image):

@christopher-beckham
Copy link
Contributor Author

@yiyixuxu Thanks that is good to know. I pushed a change so that check_image now only checks the consistency for prompt and image batch size. Though it could maybe do with a more useful name... not sure what to call it, maybe check_image_and_prompt. Let me know if you think it looks good. Thanks.

@christopher-beckham
Copy link
Contributor Author

bump @yiyixuxu thanks!

elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

if (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can remove almost all of the checks and only do following two things:

  1. make sure if it is a multi controlnet, image is a list with same length of number of controlnet
  2. add a check here
        if image_batch_size == 1:
            repeat_by = batch_size
        elif image_batch_size == batch_size:
            # image batch size is the same as prompt batch size
            repeat_by = num_images_per_prompt
        else:
             raise ValueError("...")

it should be sufficient no? would we miss anything here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the first question, I already added this checking in an earlier commit, see:

https://github.com/christopher-beckham/diffusers/blob/flux_controlnet_input_checking/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py#L458-L474

For (2), that if elif statement will break everything since the batch_size that is actually passed into prepare_image outside of it is actually batch_size*num_images_per_prompt, i.e:

control_image = self.prepare_image(
    ...
    batch_size=batch_size * num_images_per_prompt,
    num_images_per_prompt=num_images_per_prompt,
    ...
)

It's a little confusing to parse (esp since we also pass num_images_per_prompt into that method) so I changed it to the following:

control_image = self.prepare_image(
    ...
    batch_size=batch_size,
    num_images_per_prompt=num_images_per_prompt,
    ...
)

and made an adjustment to the code inside that method, so now we have:

if image_batch_size == 1:
    repeat_by = batch_size*num_images_per_prompt
elif image_batch_size == batch_size:
    # image batch size is the same as prompt batch size
    repeat_by = num_images_per_prompt
else:
    raise ValueError(...)

I wrote an informative ValueError as well in the event the else statement gets tripped.

I'll push these changes momentarily.

…if statement bypassing preprocess for torch tensor type
@christopher-beckham
Copy link
Contributor Author

Thanks for the comments above @yiyixuxu

Just one last thing, there is this to take care of:

if isinstance(image, torch.Tensor):
pass
else:
image = self.image_processor.preprocess(image, height=height, width=width)

As I previously said, I'm not sure why all the preprocessing gets skipped for torch.Tensor -- maybe it's an oversight by the original code author -- but this is not what happens for the corresponding SDXL controlnet pipeline, which runs self.image_processor.preprocess no matter what.

Fixing this however would side effect code which already uses this class with torch.Tensor. Even if the user sets width=None and height=None in pipeline.__call__ those width and height values will internally be redefined to be 1024:

height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor

I made the change in the latest commit but maybe it's worth discussing this further. If we go with my commit, then maybe it's worth adding in a warning in the event that torch.Tensor is passed.

@github-actions
Copy link
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Nov 26, 2024
@christopher-beckham
Copy link
Contributor Author

re-bump @yiyixuxu

@github-actions github-actions bot removed the stale Issues that haven't received updates label Nov 27, 2024
@github-actions
Copy link
Contributor

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Dec 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

stale Issues that haven't received updates

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants