-
Couldn't load subscription status.
- Fork 6.4k
Image dimension checking for ControlNet FLUX #9550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| from typing import Any, Callable, Dict, List, Optional, Tuple, Union | ||
|
|
||
| import numpy as np | ||
| import PIL | ||
| import torch | ||
| from transformers import ( | ||
| CLIPTextModel, | ||
|
|
@@ -389,10 +390,31 @@ def encode_prompt( | |
|
|
||
| return prompt_embeds, pooled_prompt_embeds, text_ids | ||
|
|
||
| # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image | ||
| def check_image(self, image, prompt, prompt_embeds): | ||
| image_is_pil = isinstance(image, PIL.Image.Image) | ||
| if image_is_pil: | ||
| image_batch_size = 1 | ||
| else: | ||
| image_batch_size = len(image) | ||
|
|
||
| if prompt is not None and isinstance(prompt, str): | ||
| prompt_batch_size = 1 | ||
| elif prompt is not None and isinstance(prompt, list): | ||
| prompt_batch_size = len(prompt) | ||
| elif prompt_embeds is not None: | ||
| prompt_batch_size = prompt_embeds.shape[0] | ||
|
|
||
| if image_batch_size != 1 and image_batch_size != prompt_batch_size: | ||
| raise ValueError( | ||
| f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" | ||
| ) | ||
|
|
||
| def check_inputs( | ||
| self, | ||
| prompt, | ||
| prompt_2, | ||
| image, | ||
| height, | ||
| width, | ||
| prompt_embeds=None, | ||
|
|
@@ -429,6 +451,30 @@ def check_inputs( | |
| elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): | ||
| raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") | ||
|
|
||
| if ( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we can remove almost all of the checks and only do the following two things:
if image_batch_size == 1:
repeat_by = batch_size
elif image_batch_size == batch_size:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
else:
raise ValueError("...")
It should be sufficient, no? Would we miss anything here?
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For the first question, I already added this checking in an earlier commit, see: For (2), that if elif statement will break everything since the It's a little confusing to parse (esp since we also pass and made an adjustment to the code inside that method, so now we have: I wrote an informative ValueError as well in the event the I'll push these changes momentarily. |
||
| isinstance(self.controlnet, FluxControlNetModel) | ||
| ): | ||
| self.check_image(image, prompt, prompt_embeds) | ||
| elif ( | ||
| isinstance(self.controlnet, FluxMultiControlNetModel) | ||
| ): | ||
| if not isinstance(image, list): | ||
| raise TypeError("For multiple controlnets: `image` must be type `list`") | ||
|
|
||
| # When `image` is a nested list: | ||
| # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) | ||
| elif any(isinstance(i, list) for i in image): | ||
| raise ValueError("A single batch of multiple conditionings are supported at the moment.") | ||
| elif len(image) != len(self.controlnet.nets): | ||
| raise ValueError( | ||
| f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." | ||
| ) | ||
|
|
||
| for image_ in image: | ||
| self.check_image(image_, prompt, prompt_embeds) | ||
| else: | ||
| assert False | ||
|
|
||
| if prompt_embeds is not None and pooled_prompt_embeds is None: | ||
| raise ValueError( | ||
| "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." | ||
|
|
@@ -523,18 +569,20 @@ def prepare_image( | |
| do_classifier_free_guidance=False, | ||
| guess_mode=False, | ||
| ): | ||
| if isinstance(image, torch.Tensor): | ||
| pass | ||
| else: | ||
| image = self.image_processor.preprocess(image, height=height, width=width) | ||
|
|
||
|
|
||
| image = self.image_processor.preprocess(image, height=height, width=width) | ||
| image_batch_size = image.shape[0] | ||
|
|
||
| if image_batch_size == 1: | ||
| repeat_by = batch_size | ||
| else: | ||
| repeat_by = batch_size*num_images_per_prompt | ||
| elif image_batch_size == batch_size: | ||
| # image batch size is the same as prompt batch size | ||
| repeat_by = num_images_per_prompt | ||
| else: | ||
| raise ValueError( | ||
| "`image_batch_size` must be either 1 or equal to the prompt " + \ | ||
| f"batch size, which is {batch_size}." | ||
| ) | ||
|
|
||
| image = image.repeat_interleave(repeat_by, dim=0) | ||
|
|
||
|
|
@@ -678,6 +726,7 @@ def __call__( | |
| self.check_inputs( | ||
| prompt, | ||
| prompt_2, | ||
| control_image, | ||
| height, | ||
| width, | ||
| prompt_embeds=prompt_embeds, | ||
|
|
@@ -726,7 +775,7 @@ def __call__( | |
| image=control_image, | ||
| width=width, | ||
| height=height, | ||
| batch_size=batch_size * num_images_per_prompt, | ||
| batch_size=batch_size, | ||
| num_images_per_prompt=num_images_per_prompt, | ||
| device=device, | ||
| dtype=self.vae.dtype, | ||
|
|
@@ -762,7 +811,7 @@ def __call__( | |
| image=control_image_, | ||
| width=width, | ||
| height=height, | ||
| batch_size=batch_size * num_images_per_prompt, | ||
| batch_size=batch_size, | ||
| num_images_per_prompt=num_images_per_prompt, | ||
| device=device, | ||
| dtype=self.vae.dtype, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, I think this method was written before we introduced the image processor, which sets a standard image input format that we accept across all our pipelines and checks there whether the input is in a valid format.
diffusers/src/diffusers/image_processor.py
Line 535 in c4a8979