Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/diffusers/models/controlnets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
SparseControlNetModel,
SparseControlNetOutput,
)
from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
from .controlnet_union import ControlNetUnionModel
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
from .multicontrolnet import MultiControlNetModel

Expand Down
101 changes: 8 additions & 93 deletions src/diffusers/models/controlnets/controlnet_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...image_processor import PipelineImageInput
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ..attention_processor import (
Expand All @@ -40,76 +38,6 @@
from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module


@dataclass
class ControlNetUnionInput:
"""
The image input of [`ControlNetUnionModel`]:

- 0: openpose
- 1: depth
- 2: hed/pidi/scribble/ted
- 3: canny/lineart/anime_lineart/mlsd
- 4: normal
- 5: segment
"""

openpose: Optional[PipelineImageInput] = None
depth: Optional[PipelineImageInput] = None
hed: Optional[PipelineImageInput] = None
canny: Optional[PipelineImageInput] = None
normal: Optional[PipelineImageInput] = None
segment: Optional[PipelineImageInput] = None

def __len__(self) -> int:
return len(vars(self))

def __iter__(self):
return iter(vars(self))

def __getitem__(self, key):
return getattr(self, key)

def __setitem__(self, key, value):
setattr(self, key, value)


@dataclass
class ControlNetUnionInputProMax:
"""
The image input of [`ControlNetUnionModel`]:

- 0: openpose
- 1: depth
- 2: hed/pidi/scribble/ted
- 3: canny/lineart/anime_lineart/mlsd
- 4: normal
- 5: segment
- 6: tile
- 7: repaint
"""

openpose: Optional[PipelineImageInput] = None
depth: Optional[PipelineImageInput] = None
hed: Optional[PipelineImageInput] = None
canny: Optional[PipelineImageInput] = None
normal: Optional[PipelineImageInput] = None
segment: Optional[PipelineImageInput] = None
tile: Optional[PipelineImageInput] = None
repaint: Optional[PipelineImageInput] = None

def __len__(self) -> int:
return len(vars(self))

def __iter__(self):
return iter(vars(self))

def __getitem__(self, key):
return getattr(self, key)

def __setitem__(self, key, value):
setattr(self, key, value)


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


Expand Down Expand Up @@ -680,8 +608,9 @@ def forward(
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
controlnet_cond: List[torch.Tensor],
control_type: torch.Tensor,
control_type_idx: List[int],
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
Expand All @@ -701,11 +630,13 @@ def forward(
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
controlnet_cond (`List[torch.Tensor]`):
The conditional input tensors.
control_type (`torch.Tensor`):
A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
type is used.
control_type_idx (`List[int]`):
The indices of `control_type`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Expand Down Expand Up @@ -733,20 +664,6 @@ def forward(
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(controlnet_cond) != self.config.num_control_type:
if isinstance(controlnet_cond, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(controlnet_cond, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`."
)

# check channel order
channel_order = self.config.controlnet_conditioning_channel_order

Expand Down Expand Up @@ -830,12 +747,10 @@ def forward(
inputs = []
condition_list = []

for idx, image_type in enumerate(controlnet_cond):
if controlnet_cond[image_type] is None:
continue
condition = self.controlnet_cond_embedding(controlnet_cond[image_type])
for cond, control_idx in zip(controlnet_cond, control_type_idx):
condition = self.controlnet_cond_embedding(cond)
feat_seq = torch.mean(condition, dim=(2, 3))
feat_seq = feat_seq + self.task_embedding[idx]
feat_seq = feat_seq + self.task_embedding[control_idx]
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(condition)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
AttnProcessor2_0,
XFormersAttnProcessor,
)
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
Expand Down Expand Up @@ -82,7 +81,6 @@ def retrieve_latents(
Examples:
```py
from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
from diffusers.models.controlnets import ControlNetUnionInputProMax
from diffusers.utils import load_image
import torch
import numpy as np
Expand Down Expand Up @@ -114,11 +112,8 @@ def retrieve_latents(
mask_np = np.array(mask)
controlnet_img_np[mask_np > 0] = 0
controlnet_img = Image.fromarray(controlnet_img_np)
union_input = ControlNetUnionInputProMax(
repaint=controlnet_img,
)
# generate image
image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]
image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0]
image.save("inpaint.png")
```
"""
Expand Down Expand Up @@ -1130,7 +1125,7 @@ def __call__(
prompt_2: Optional[Union[str, List[str]]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
padding_mask_crop: Optional[int] = None,
Expand Down Expand Up @@ -1158,6 +1153,7 @@ def __call__(
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_mode: Optional[Union[int, List[int]]] = None,
guidance_rescale: float = 0.0,
original_size: Tuple[int, int] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
Expand Down Expand Up @@ -1345,20 +1341,6 @@ def __call__(

controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
raise ValueError(
"Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
)
if len(control_image_list) != controlnet.config.num_control_type:
if isinstance(control_image_list, ControlNetUnionInput):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
)
elif isinstance(control_image_list, ControlNetUnionInputProMax):
raise ValueError(
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
)

# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
Expand All @@ -1375,36 +1357,44 @@ def __call__(
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]

if not isinstance(control_image, list):
control_image = [control_image]

if not isinstance(control_mode, list):
control_mode = [control_mode]

if len(control_image) != len(control_mode):
raise ValueError("Expected len(control_image) == len(control_type)")

num_control_type = controlnet.config.num_control_type

# 1. Check inputs
control_type = []
for image_type in control_image_list:
if control_image_list[image_type]:
self.check_inputs(
prompt,
prompt_2,
control_image_list[image_type],
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)
control_type.append(1)
else:
control_type.append(0)
control_type = [0 for _ in range(num_control_type)]
for _image, control_idx in zip(control_image, control_mode):
control_type[control_idx] = 1
self.check_inputs(
prompt,
prompt_2,
_image,
mask_image,
strength,
num_inference_steps,
callback_steps,
output_type,
negative_prompt,
negative_prompt_2,
prompt_embeds,
negative_prompt_embeds,
ip_adapter_image,
ip_adapter_image_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
callback_on_step_end_tensor_inputs,
padding_mask_crop,
)

control_type = torch.Tensor(control_type)

Expand Down Expand Up @@ -1499,23 +1489,21 @@ def denoising_value_valid(dnv):
init_image = init_image.to(dtype=torch.float32)

# 5.2 Prepare control images
for image_type in control_image_list:
if control_image_list[image_type]:
control_image = self.prepare_control_image(
image=control_image_list[image_type],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image.shape[-2:]
control_image_list[image_type] = control_image
for idx, _ in enumerate(control_image):
control_image[idx] = self.prepare_control_image(
image=control_image[idx],
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
crops_coords=crops_coords,
resize_mode=resize_mode,
do_classifier_free_guidance=self.do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = control_image[idx].shape[-2:]

# 5.3 Prepare mask
mask = self.mask_processor.preprocess(
Expand Down Expand Up @@ -1589,6 +1577,9 @@ def denoising_value_valid(dnv):

original_size = original_size or (height, width)
target_size = target_size or (height, width)
for _image in control_image:
if isinstance(_image, torch.Tensor):
original_size = original_size or _image.shape[-2:]

# 10. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
Expand Down Expand Up @@ -1693,8 +1684,9 @@ def denoising_value_valid(dnv):
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=control_image_list,
controlnet_cond=control_image,
control_type=control_type,
control_type_idx=control_mode,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
added_cond_kwargs=controlnet_added_cond_kwargs,
Expand Down
Loading
Loading