Skip to content

Commit 5930bcf

Browse files
committed
fixes
1 parent 6cc3323 commit 5930bcf

File tree

7 files changed

+109
-199
lines changed

7 files changed

+109
-199
lines changed

src/diffusers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
)
102102
from .controlnets import (
103103
ControlNetModel,
104+
ControlNetUnionModel,
104105
ControlNetXSAdapter,
105106
FluxControlNetModel,
106107
FluxMultiControlNetModel,

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
SparseControlNetModel,
1616
SparseControlNetOutput,
1717
)
18+
from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
1819
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
1920
from .multicontrolnet import MultiControlNetModel
2021

src/diffusers/models/controlnet_union.py renamed to src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,31 @@
1818
from torch import nn
1919
from transformers.activations import QuickGELUActivation as QuickGELU
2020

21-
from ..configuration_utils import ConfigMixin, register_to_config
22-
from ..image_processor import PipelineImageInput
23-
from ..loaders.single_file_model import FromOriginalModelMixin
24-
from ..utils import BaseInput, logging
25-
from .attention_processor import (
21+
from ...configuration_utils import ConfigMixin, register_to_config
22+
from ...image_processor import PipelineImageInput
23+
from ...loaders.single_file_model import FromOriginalModelMixin
24+
from ...utils import logging
25+
from ..attention_processor import (
2626
ADDED_KV_ATTENTION_PROCESSORS,
2727
CROSS_ATTENTION_PROCESSORS,
2828
AttentionProcessor,
2929
AttnAddedKVProcessor,
3030
AttnProcessor,
3131
)
32-
from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
33-
from .embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
34-
from .modeling_utils import ModelMixin
35-
from .unets.unet_2d_blocks import (
32+
from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
33+
from ..modeling_utils import ModelMixin
34+
from ..unets.unet_2d_blocks import (
3635
CrossAttnDownBlock2D,
3736
DownBlock2D,
3837
UNetMidBlock2DCrossAttn,
3938
get_down_block,
4039
)
41-
from .unets.unet_2d_condition import UNet2DConditionModel
40+
from ..unets.unet_2d_condition import UNet2DConditionModel
41+
from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
4242

4343

4444
@dataclass
45-
class ControlNetUnionInput(BaseInput):
45+
class ControlNetUnionInput:
4646
"""
4747
The image input of [`ControlNetUnionModel`]:
4848
@@ -54,18 +54,27 @@ class ControlNetUnionInput(BaseInput):
5454
- 5: segment
5555
"""
5656

57-
openpose: PipelineImageInput = None
58-
depth: PipelineImageInput = None
59-
hed: PipelineImageInput = None
60-
canny: PipelineImageInput = None
61-
normal: PipelineImageInput = None
62-
segment: PipelineImageInput = None
57+
openpose: Optional[PipelineImageInput] = None
58+
depth: Optional[PipelineImageInput] = None
59+
hed: Optional[PipelineImageInput] = None
60+
canny: Optional[PipelineImageInput] = None
61+
normal: Optional[PipelineImageInput] = None
62+
segment: Optional[PipelineImageInput] = None
63+
64+
def __len__(self) -> int:
65+
return len(vars(self))
66+
67+
def __iter__(self):
68+
return iter(vars(self))
69+
70+
def __getitem__(self, key):
71+
return getattr(self, key)
6372

6473

6574
@dataclass
66-
class ControlNetUnionInputProMax(BaseInput):
75+
class ControlNetUnionInputProMax:
6776
"""
68-
The image input of [`ControlNetUnionModel`] for ProMax variants:
77+
The image input of [`ControlNetUnionModel`]:
6978
7079
- 0: openpose
7180
- 1: depth
@@ -77,14 +86,23 @@ class ControlNetUnionInputProMax(BaseInput):
7786
- 7: repaint
7887
"""
7988

80-
openpose: PipelineImageInput = None
81-
depth: PipelineImageInput = None
82-
hed: PipelineImageInput = None
83-
canny: PipelineImageInput = None
84-
normal: PipelineImageInput = None
85-
segment: PipelineImageInput = None
86-
tile: PipelineImageInput = None
87-
repaint: PipelineImageInput = None
89+
openpose: Optional[PipelineImageInput] = None
90+
depth: Optional[PipelineImageInput] = None
91+
hed: Optional[PipelineImageInput] = None
92+
canny: Optional[PipelineImageInput] = None
93+
normal: Optional[PipelineImageInput] = None
94+
segment: Optional[PipelineImageInput] = None
95+
tile: Optional[PipelineImageInput] = None
96+
repaint: Optional[PipelineImageInput] = None
97+
98+
def __len__(self) -> int:
99+
return len(vars(self))
100+
101+
def __iter__(self):
102+
return iter(vars(self))
103+
104+
def __getitem__(self, key):
105+
return getattr(self, key)
88106

89107

90108
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
AttnProcessor2_0,
4141
XFormersAttnProcessor,
4242
)
43-
from ...models.controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax
43+
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
4444
from ...models.lora import adjust_lora_scale_text_encoder
4545
from ...schedulers import KarrasDiffusionSchedulers
4646
from ...utils import (
@@ -605,6 +605,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
605605
extra_step_kwargs["generator"] = generator
606606
return extra_step_kwargs
607607

608+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
608609
def check_image(self, image, prompt, prompt_embeds):
609610
image_is_pil = isinstance(image, PIL.Image.Image)
610611
image_is_tensor = isinstance(image, torch.Tensor)
@@ -826,6 +827,7 @@ def check_inputs(
826827
f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
827828
)
828829

830+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
829831
def prepare_control_image(
830832
self,
831833
image,
@@ -860,6 +862,7 @@ def prepare_control_image(
860862

861863
return image
862864

865+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_latents
863866
def prepare_latents(
864867
self,
865868
batch_size,
@@ -927,6 +930,7 @@ def prepare_latents(
927930

928931
return outputs
929932

933+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline._encode_vae_image
930934
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
931935
dtype = image.dtype
932936
if self.vae.config.force_upcast:
@@ -950,6 +954,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
950954

951955
return image_latents
952956

957+
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet_inpaint_sd_xl.StableDiffusionXLControlNetInpaintPipeline.prepare_mask_latents
953958
def prepare_mask_latents(
954959
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
955960
):
@@ -1560,7 +1565,7 @@ def denoising_value_valid(dnv):
15601565
latents, noise = latents_outputs
15611566

15621567
# 7. Prepare mask latent variables
1563-
mask, masked_image_latents = self.prepare_mask_latents(
1568+
mask, _ = self.prepare_mask_latents(
15641569
mask,
15651570
masked_image,
15661571
batch_size * num_images_per_prompt,
@@ -1573,19 +1578,7 @@ def denoising_value_valid(dnv):
15731578
)
15741579

15751580
# 8. Check that sizes of mask, masked image and latents match
1576-
if num_channels_unet == 9:
1577-
# default case for runwayml/stable-diffusion-inpainting
1578-
num_channels_mask = mask.shape[1]
1579-
num_channels_masked_image = masked_image_latents.shape[1]
1580-
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
1581-
raise ValueError(
1582-
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1583-
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1584-
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1585-
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1586-
" `pipeline.unet` or your `mask_image` or `image` input."
1587-
)
1588-
elif num_channels_unet != 4:
1581+
if num_channels_unet != 4:
15891582
raise ValueError(
15901583
f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
15911584
)
@@ -1673,7 +1666,6 @@ def denoising_value_valid(dnv):
16731666
# expand the latents if we are doing classifier free guidance
16741667
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
16751668

1676-
# concat latents, mask, masked_image_latents in the channel dimension
16771669
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
16781670

16791671
added_cond_kwargs = {
@@ -1730,9 +1722,6 @@ def denoising_value_valid(dnv):
17301722
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
17311723
added_cond_kwargs["image_embeds"] = image_embeds
17321724

1733-
if num_channels_unet == 9:
1734-
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1735-
17361725
# predict the noise residual
17371726
noise_pred = self.unet(
17381727
latent_model_input,
@@ -1757,20 +1746,19 @@ def denoising_value_valid(dnv):
17571746
# compute the previous noisy sample x_t -> x_t-1
17581747
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
17591748

1760-
if num_channels_unet == 4:
1761-
init_latents_proper = image_latents
1762-
if self.do_classifier_free_guidance:
1763-
init_mask, _ = mask.chunk(2)
1764-
else:
1765-
init_mask = mask
1749+
init_latents_proper = image_latents
1750+
if self.do_classifier_free_guidance:
1751+
init_mask, _ = mask.chunk(2)
1752+
else:
1753+
init_mask = mask
17661754

1767-
if i < len(timesteps) - 1:
1768-
noise_timestep = timesteps[i + 1]
1769-
init_latents_proper = self.scheduler.add_noise(
1770-
init_latents_proper, noise, torch.tensor([noise_timestep])
1771-
)
1755+
if i < len(timesteps) - 1:
1756+
noise_timestep = timesteps[i + 1]
1757+
init_latents_proper = self.scheduler.add_noise(
1758+
init_latents_proper, noise, torch.tensor([noise_timestep])
1759+
)
17721760

1773-
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1761+
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
17741762

17751763
if callback_on_step_end is not None:
17761764
callback_kwargs = {}

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import inspect
1717
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1818

19+
import numpy as np
20+
import PIL.Image
1921
import torch
2022
import torch.nn.functional as F
2123
from transformers import (
@@ -41,7 +43,7 @@
4143
AttnProcessor2_0,
4244
XFormersAttnProcessor,
4345
)
44-
from ...models.controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax
46+
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
4547
from ...models.lora import adjust_lora_scale_text_encoder
4648
from ...schedulers import KarrasDiffusionSchedulers
4749
from ...utils import (
@@ -608,6 +610,44 @@ def prepare_extra_step_kwargs(self, generator, eta):
608610
extra_step_kwargs["generator"] = generator
609611
return extra_step_kwargs
610612

613+
# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.check_image
# (adds guards for an empty image list and for a missing prompt / prompt_embeds).
def check_image(self, image, prompt, prompt_embeds) -> None:
    """
    Validate the type of a control image and compare its batch size with the prompt batch size.

    Args:
        image: A PIL image, a numpy array, a torch tensor, or a non-empty list of one of those.
        prompt: A string, a list of strings, or `None`.
        prompt_embeds: Pre-computed prompt embeddings. Only consulted (via `shape[0]`) when `prompt` is neither a
            string nor a list.

    Raises:
        TypeError: If `image` is not one of the supported types. An empty list is rejected here as well.
        ValueError: If the image batch size is not 1 and it either differs from the prompt batch size or no prompt
            batch size can be determined from `prompt` / `prompt_embeds`.
    """

    def _is_list_of(kind):
        # `len(image) > 0` keeps an empty list from raising IndexError on `image[0]`.
        return isinstance(image, list) and len(image) > 0 and isinstance(image[0], kind)

    image_is_pil = isinstance(image, PIL.Image.Image)
    image_is_tensor = isinstance(image, torch.Tensor)
    image_is_np = isinstance(image, np.ndarray)
    image_is_pil_list = _is_list_of(PIL.Image.Image)
    image_is_tensor_list = _is_list_of(torch.Tensor)
    image_is_np_list = _is_list_of(np.ndarray)

    if not (
        image_is_pil
        or image_is_tensor
        or image_is_np
        or image_is_pil_list
        or image_is_tensor_list
        or image_is_np_list
    ):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    # A single PIL image counts as a batch of one; everything else reports its own length.
    image_batch_size = 1 if image_is_pil else len(image)

    if isinstance(prompt, str):
        prompt_batch_size = 1
    elif isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]
    else:
        # Neither input gives a batch size; keep the name bound so the check below cannot
        # fail with UnboundLocalError.
        prompt_batch_size = None

    # A single image is accepted for any prompt batch size.
    if image_batch_size == 1:
        return

    if prompt_batch_size is None:
        raise ValueError(
            "Cannot determine the prompt batch size: pass `prompt` as a string or a list of strings, or pass"
            f" `prompt_embeds`. image batch size: {image_batch_size}"
        )

    if image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
650+
611651
def check_inputs(
612652
self,
613653
prompt,
@@ -1228,7 +1268,7 @@ def __call__(
12281268
# 4. Prepare image
12291269
for image_type in image:
12301270
if image[image_type]:
1231-
image = self.prepare_image(
1271+
image[image_type] = self.prepare_image(
12321272
image=image[image_type],
12331273
width=width,
12341274
height=height,
@@ -1240,7 +1280,6 @@ def __call__(
12401280
guess_mode=guess_mode,
12411281
)
12421282
height, width = image.shape[-2:]
1243-
image[image_type] = image
12441283

12451284
# 5. Prepare timesteps
12461285
timesteps, num_inference_steps = retrieve_timesteps(

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
AttnProcessor2_0,
4444
XFormersAttnProcessor,
4545
)
46-
from ...models.controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax
46+
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
4747
from ...models.lora import adjust_lora_scale_text_encoder
4848
from ...schedulers import KarrasDiffusionSchedulers
4949
from ...utils import (

0 commit comments

Comments
 (0)