
Commit c788d8e

Refactor to use ControlNetUnionInput, introduce BaseInput
1 parent 68cbdb3 commit c788d8e

File tree

8 files changed (+329, -74 lines)

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

```diff
@@ -219,6 +219,8 @@
     title: Logging
   - local: api/outputs
     title: Outputs
+  - local: api/inputs
+    title: Inputs
   - local: api/quantization
     title: Quantization
   title: Main Classes
```

docs/source/en/api/inputs.md

Lines changed: 38 additions & 0 deletions (new file)

<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Inputs

Some model inputs are subclasses of [`~utils.BaseInput`], data structures containing all the information needed by the model. The inputs can also be used as tuples or dictionaries.

For example:

```python
from diffusers.models.controlnet_union import ControlNetUnionInput

union_input = ControlNetUnionInput(
    openpose=...
)
```

When treated as a tuple, the `inputs` object includes all of its attributes, including those with `None` values.

<Tip>

To check a specific pipeline or model input, refer to its corresponding API documentation.

</Tip>

## BaseInput

[[autodoc]] utils.BaseInput
    - to_tuple
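As a quick illustration of the dictionary/tuple behavior described above, here is a minimal sketch; it assumes `BaseInput` exposes the key-based access, iteration, and item assignment that the model and pipeline loops in this commit rely on:

```python
# Minimal sketch of the access patterns described above; assumes BaseInput
# supports key access, iteration over field names, item assignment, and
# to_tuple(), as the loops in this commit suggest.
from PIL import Image

from diffusers.models.controlnet_union import ControlNetUnionInput

pose = Image.new("RGB", (1024, 1024))  # stand-in conditioning image
union_input = ControlNetUnionInput(openpose=pose)

# Dictionary-style access by field name
assert union_input["openpose"] is pose

# Item assignment, as the pipeline uses when replacing raw images
# with preprocessed tensors
union_input["openpose"] = pose

# Iterating yields the field names, in declaration order
for image_type in union_input:
    print(image_type, union_input[image_type] is not None)

# The tuple view keeps all six attributes, including the None ones
assert len(union_input.to_tuple()) == 6
```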

src/diffusers/models/controlnet_union.py

Lines changed: 70 additions & 18 deletions

```diff
@@ -11,15 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
 from transformers.activations import QuickGELUActivation as QuickGELU
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..image_processor import PipelineImageInput
 from ..loaders.single_file_model import FromOriginalModelMixin
-from ..utils import logging
+from ..utils import BaseInput, logging
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
     CROSS_ATTENTION_PROCESSORS,
@@ -39,6 +41,52 @@
 from .unets.unet_2d_condition import UNet2DConditionModel
 
 
+@dataclass
+class ControlNetUnionInput(BaseInput):
+    """
+    The image input of [`ControlNetUnionModel`]:
+
+    - 0: openpose
+    - 1: depth
+    - 2: hed/pidi/scribble/ted
+    - 3: canny/lineart/anime_lineart/mlsd
+    - 4: normal
+    - 5: segment
+    """
+
+    openpose: PipelineImageInput = None
+    depth: PipelineImageInput = None
+    hed: PipelineImageInput = None
+    canny: PipelineImageInput = None
+    normal: PipelineImageInput = None
+    segment: PipelineImageInput = None
+
+
+@dataclass
+class ControlNetUnionInputProMax(BaseInput):
+    """
+    The image input of [`ControlNetUnionModel`] for ProMax variants:
+
+    - 0: openpose
+    - 1: depth
+    - 2: hed/pidi/scribble/ted
+    - 3: canny/lineart/anime_lineart/mlsd
+    - 4: normal
+    - 5: segment
+    - 6: tile
+    - 7: repaint
+    """
+
+    openpose: PipelineImageInput = None
+    depth: PipelineImageInput = None
+    hed: PipelineImageInput = None
+    canny: PipelineImageInput = None
+    normal: PipelineImageInput = None
+    segment: PipelineImageInput = None
+    tile: PipelineImageInput = None
+    repaint: PipelineImageInput = None
+
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
@@ -624,7 +672,8 @@ def forward(
         sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        controlnet_cond_list: List[torch.Tensor],
+        controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
+        control_type: torch.Tensor,
         conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
@@ -644,8 +693,11 @@ def forward(
                 The number of timesteps to denoise an input.
             encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states.
-            controlnet_cond_list (`List[torch.Tensor]`):
-                List of the conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+            controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
+                The conditional input tensors.
+            control_type (`torch.Tensor`):
+                A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the
+                control type is used.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
@@ -743,7 +795,6 @@ def forward(
             add_embeds = add_embeds.to(emb.dtype)
             aug_emb = self.add_embedding(add_embeds)
 
-        control_type = added_cond_kwargs.get("control_type")
         control_embeds = self.control_type_proj(control_type.flatten())
         control_embeds = control_embeds.reshape((t_emb.shape[0], -1))
         control_embeds = control_embeds.to(emb.dtype)
@@ -753,32 +804,33 @@ def forward(
 
         # 2. pre-process
         sample = self.conv_in(sample)
-        indices = torch.nonzero(control_type[0])
 
         inputs = []
         condition_list = []
 
-        for idx in range(indices.shape[0] + 1):
-            if idx == indices.shape[0]:
-                controlnet_cond = sample
-                feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
-            else:
-                controlnet_cond = self.controlnet_cond_embedding(controlnet_cond_list[indices[idx][0]])
-                feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
-                feat_seq = feat_seq + self.task_embedding[indices[idx][0]]
-
+        for idx, image_type in enumerate(controlnet_cond):
+            if controlnet_cond[image_type] is None:
+                continue
+            condition = self.controlnet_cond_embedding(controlnet_cond[image_type])
+            feat_seq = torch.mean(condition, dim=(2, 3))
+            feat_seq = feat_seq + self.task_embedding[idx]
             inputs.append(feat_seq.unsqueeze(1))
-            condition_list.append(controlnet_cond)
+            condition_list.append(condition)
+
+        condition = sample
+        feat_seq = torch.mean(condition, dim=(2, 3))
+        inputs.append(feat_seq.unsqueeze(1))
+        condition_list.append(condition)
 
         x = torch.cat(inputs, dim=1)
         for layer in self.transformer_layers:
             x = layer(x)
 
         controlnet_cond_fuser = sample * 0.0
-        for idx in range(indices.shape[0]):
+        for idx, condition in enumerate(condition_list):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            controlnet_cond_fuser += condition_list[idx] + alpha
+            controlnet_cond_fuser += condition + alpha
 
         sample = sample + controlnet_cond_fuser
```
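For orientation, a minimal sketch of how the refactored `forward` receives its new arguments; the tensor shapes and the commented-out `controlnet` call below are illustrative assumptions, not part of the commit:

```python
# Sketch of the new forward() inputs; shapes are placeholder assumptions.
import torch

from diffusers.models.controlnet_union import ControlNetUnionInput

# Structured conditioning: unused control types simply stay None.
union_input = ControlNetUnionInput(
    depth=torch.randn(2, 3, 1024, 1024),  # placeholder depth-map batch
)

# control_type flags which of the six slots are active, in field order,
# matching the task_embedding indices used inside forward().
control_type = torch.tensor(
    [[0 if cond is None else 1 for cond in union_input.to_tuple()]],
    dtype=torch.float32,
)

# The model would then be called roughly like (other arguments elided):
# down_block_res_samples, mid_block_res_sample = controlnet(
#     sample,
#     timestep,
#     encoder_hidden_states=prompt_embeds,
#     controlnet_cond=union_input,
#     control_type=control_type,
#     conditioning_scale=1.0,
# )
```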

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 22 additions & 17 deletions

```diff
@@ -40,6 +40,7 @@
     AttnProcessor2_0,
     XFormersAttnProcessor,
 )
+from ...models.controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
@@ -1184,10 +1185,7 @@ def __call__(
         prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
         mask_image: PipelineImageInput = None,
-        control_image_list: Union[
-            PipelineImageInput,
-            List[PipelineImageInput],
-        ] = None,
+        control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         padding_mask_crop: Optional[int] = None,
@@ -1226,8 +1224,6 @@ def __call__(
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        union_control=False,
-        union_control_type=None,
         **kwargs,
     ):
         r"""
@@ -1433,12 +1429,13 @@ def __call__(
         )
 
         # 1. Check inputs
-        for control_image in control_image_list:
-            if control_image:
+        control_type = []
+        for image_type in control_image_list:
+            if control_image_list[image_type]:
                 self.check_inputs(
                     prompt,
                     prompt_2,
-                    control_image,
+                    control_image_list[image_type],
                     mask_image,
                     strength,
                     num_inference_steps,
@@ -1458,6 +1455,11 @@ def __call__(
                     callback_on_step_end_tensor_inputs,
                     padding_mask_crop,
                 )
+                control_type.append(1)
+            else:
+                control_type.append(0)
+
+        control_type = torch.Tensor(control_type)
 
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
@@ -1553,10 +1555,10 @@ def denoising_value_valid(dnv):
             init_image = init_image.to(dtype=torch.float32)
 
         # 5.2 Prepare control images
-        for idx in range(len(control_image_list)):
-            if control_image_list[idx]:
+        for image_type in control_image_list:
+            if control_image_list[image_type]:
                 control_image = self.prepare_control_image(
-                    image=control_image_list[idx],
+                    image=control_image_list[image_type],
                     width=width,
                     height=height,
                     batch_size=batch_size * num_images_per_prompt,
@@ -1569,7 +1571,7 @@ def denoising_value_valid(dnv):
                     guess_mode=guess_mode,
                 )
                 height, width = control_image.shape[-2:]
-                control_image_list[idx] = control_image
+                control_image_list[image_type] = control_image
 
         # 5.3 Prepare mask
         mask = self.mask_processor.preprocess(
@@ -1709,6 +1711,11 @@ def denoising_value_valid(dnv):
             num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
             timesteps = timesteps[:num_inference_steps]
 
+        control_type = (
+            control_type.reshape(1, -1)
+            .to(device, dtype=prompt_embeds.dtype)
+            .repeat(batch_size * num_images_per_prompt * 2, 1)
+        )
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
@@ -1723,9 +1730,6 @@ def denoising_value_valid(dnv):
                 added_cond_kwargs = {
                     "text_embeds": add_text_embeds,
                     "time_ids": add_time_ids,
-                    "control_type": union_control_type.reshape(1, -1)
-                    .to(device, dtype=prompt_embeds.dtype)
-                    .repeat(batch_size * num_images_per_prompt * 2, 1),
                 }
 
                 # controlnet(s) inference
@@ -1759,7 +1763,8 @@ def denoising_value_valid(dnv):
                     control_model_input,
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
-                    controlnet_cond_list=control_image_list,
+                    controlnet_cond=control_image_list,
+                    control_type=control_type,
                     conditioning_scale=cond_scale,
                     guess_mode=guess_mode,
                     added_cond_kwargs=controlnet_added_cond_kwargs,
```
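Putting the pipeline changes together, a caller-side sketch; the pipeline and model class names are inferred from this file's path, and the checkpoint ids are illustrative assumptions rather than anything documented in the commit:

```python
# Caller-side sketch; class and checkpoint names are assumptions inferred
# from the file path and the ControlNet-Union ecosystem, not this commit.
import torch

from diffusers.models.controlnet_union import (  # names assumed
    ControlNetUnionInput,
    ControlNetUnionModel,
)
from diffusers.pipelines.controlnet.pipeline_controlnet_union_inpaint_sd_xl import (
    StableDiffusionXLControlNetUnionInpaintPipeline,  # assumed class name
)
from diffusers.utils import load_image

controlnet = ControlNetUnionModel.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16  # illustrative
)
pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # illustrative base checkpoint
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# Fill only the slots you condition on; the pipeline derives control_type
# (one 0/1 flag per slot) from which fields are set.
union_input = ControlNetUnionInput(
    canny=load_image("canny_edges.png"),  # hypothetical local file
)

image = pipe(
    prompt="a modern living room",
    image=load_image("room.png"),            # hypothetical
    mask_image=load_image("room_mask.png"),  # hypothetical
    control_image_list=union_input,
).images[0]
```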
