Skip to content

Commit 8d405db

Browse files
committed
refactor StableDiffusionXLControlNetUnion
mode
1 parent 96c376a commit 8d405db

File tree

5 files changed

+185
-310
lines changed

5 files changed

+185
-310
lines changed

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
SparseControlNetModel,
1616
SparseControlNetOutput,
1717
)
18-
from .controlnet_union import ControlNetUnionInput, ControlNetUnionInputProMax, ControlNetUnionModel
18+
from .controlnet_union import ControlNetUnionModel
1919
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
2020
from .multicontrolnet import MultiControlNetModel
2121

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 8 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from dataclasses import dataclass
1514
from typing import Any, Dict, List, Optional, Tuple, Union
1615

1716
import torch
1817
from torch import nn
1918

2019
from ...configuration_utils import ConfigMixin, register_to_config
21-
from ...image_processor import PipelineImageInput
2220
from ...loaders.single_file_model import FromOriginalModelMixin
2321
from ...utils import logging
2422
from ..attention_processor import (
@@ -40,76 +38,6 @@
4038
from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
4139

4240

43-
@dataclass
44-
class ControlNetUnionInput:
45-
"""
46-
The image input of [`ControlNetUnionModel`]:
47-
48-
- 0: openpose
49-
- 1: depth
50-
- 2: hed/pidi/scribble/ted
51-
- 3: canny/lineart/anime_lineart/mlsd
52-
- 4: normal
53-
- 5: segment
54-
"""
55-
56-
openpose: Optional[PipelineImageInput] = None
57-
depth: Optional[PipelineImageInput] = None
58-
hed: Optional[PipelineImageInput] = None
59-
canny: Optional[PipelineImageInput] = None
60-
normal: Optional[PipelineImageInput] = None
61-
segment: Optional[PipelineImageInput] = None
62-
63-
def __len__(self) -> int:
64-
return len(vars(self))
65-
66-
def __iter__(self):
67-
return iter(vars(self))
68-
69-
def __getitem__(self, key):
70-
return getattr(self, key)
71-
72-
def __setitem__(self, key, value):
73-
setattr(self, key, value)
74-
75-
76-
@dataclass
77-
class ControlNetUnionInputProMax:
78-
"""
79-
The image input of [`ControlNetUnionModel`]:
80-
81-
- 0: openpose
82-
- 1: depth
83-
- 2: hed/pidi/scribble/ted
84-
- 3: canny/lineart/anime_lineart/mlsd
85-
- 4: normal
86-
- 5: segment
87-
- 6: tile
88-
- 7: repaint
89-
"""
90-
91-
openpose: Optional[PipelineImageInput] = None
92-
depth: Optional[PipelineImageInput] = None
93-
hed: Optional[PipelineImageInput] = None
94-
canny: Optional[PipelineImageInput] = None
95-
normal: Optional[PipelineImageInput] = None
96-
segment: Optional[PipelineImageInput] = None
97-
tile: Optional[PipelineImageInput] = None
98-
repaint: Optional[PipelineImageInput] = None
99-
100-
def __len__(self) -> int:
101-
return len(vars(self))
102-
103-
def __iter__(self):
104-
return iter(vars(self))
105-
106-
def __getitem__(self, key):
107-
return getattr(self, key)
108-
109-
def __setitem__(self, key, value):
110-
setattr(self, key, value)
111-
112-
11341
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
11442

11543

@@ -680,8 +608,9 @@ def forward(
680608
sample: torch.Tensor,
681609
timestep: Union[torch.Tensor, float, int],
682610
encoder_hidden_states: torch.Tensor,
683-
controlnet_cond: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
611+
controlnet_cond: List[torch.Tensor],
684612
control_type: torch.Tensor,
613+
control_type_idx: List[int],
685614
conditioning_scale: float = 1.0,
686615
class_labels: Optional[torch.Tensor] = None,
687616
timestep_cond: Optional[torch.Tensor] = None,
@@ -701,11 +630,13 @@ def forward(
701630
The number of timesteps to denoise an input.
702631
encoder_hidden_states (`torch.Tensor`):
703632
The encoder hidden states.
704-
controlnet_cond (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
633+
controlnet_cond (`List[torch.Tensor]`):
705634
The conditional input tensors.
706635
control_type (`torch.Tensor`):
707636
A tensor of shape `(batch, num_control_type)` with values `0` or `1` depending on whether the control
708637
type is used.
638+
control_type_idx (`List[int]`):
639+
The indices of `control_type`.
709640
conditioning_scale (`float`, defaults to `1.0`):
710641
The scale factor for ControlNet outputs.
711642
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
@@ -733,20 +664,6 @@ def forward(
733664
If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
734665
returned where the first element is the sample tensor.
735666
"""
736-
if not isinstance(controlnet_cond, (ControlNetUnionInput, ControlNetUnionInputProMax)):
737-
raise ValueError(
738-
"Expected type of `controlnet_cond` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
739-
)
740-
if len(controlnet_cond) != self.config.num_control_type:
741-
if isinstance(controlnet_cond, ControlNetUnionInput):
742-
raise ValueError(
743-
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInputProMax`."
744-
)
745-
elif isinstance(controlnet_cond, ControlNetUnionInputProMax):
746-
raise ValueError(
747-
f"Expected num_control_type {self.config.num_control_type}, got {len(controlnet_cond)}. Try `ControlNetUnionInput`."
748-
)
749-
750667
# check channel order
751668
channel_order = self.config.controlnet_conditioning_channel_order
752669

@@ -830,12 +747,10 @@ def forward(
830747
inputs = []
831748
condition_list = []
832749

833-
for idx, image_type in enumerate(controlnet_cond):
834-
if controlnet_cond[image_type] is None:
835-
continue
836-
condition = self.controlnet_cond_embedding(controlnet_cond[image_type])
750+
for cond, control_idx in zip(controlnet_cond, control_type_idx):
751+
condition = self.controlnet_cond_embedding(cond)
837752
feat_seq = torch.mean(condition, dim=(2, 3))
838-
feat_seq = feat_seq + self.task_embedding[idx]
753+
feat_seq = feat_seq + self.task_embedding[control_idx]
839754
inputs.append(feat_seq.unsqueeze(1))
840755
condition_list.append(condition)
841756

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 60 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
AttnProcessor2_0,
4141
XFormersAttnProcessor,
4242
)
43-
from ...models.controlnets import ControlNetUnionInput, ControlNetUnionInputProMax
4443
from ...models.lora import adjust_lora_scale_text_encoder
4544
from ...schedulers import KarrasDiffusionSchedulers
4645
from ...utils import (
@@ -82,7 +81,6 @@ def retrieve_latents(
8281
Examples:
8382
```py
8483
from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
85-
from diffusers.models.controlnets import ControlNetUnionInputProMax
8684
from diffusers.utils import load_image
8785
import torch
8886
import numpy as np
@@ -114,11 +112,8 @@ def retrieve_latents(
114112
mask_np = np.array(mask)
115113
controlnet_img_np[mask_np > 0] = 0
116114
controlnet_img = Image.fromarray(controlnet_img_np)
117-
union_input = ControlNetUnionInputProMax(
118-
repaint=controlnet_img,
119-
)
120115
# generate image
121-
image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]
116+
image = pipe(prompt, image=image, mask_image=mask, control_image=[controlnet_img], control_mode=[7]).images[0]
122117
image.save("inpaint.png")
123118
```
124119
"""
@@ -1130,7 +1125,7 @@ def __call__(
11301125
prompt_2: Optional[Union[str, List[str]]] = None,
11311126
image: PipelineImageInput = None,
11321127
mask_image: PipelineImageInput = None,
1133-
control_image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
1128+
control_image: PipelineImageInput = None,
11341129
height: Optional[int] = None,
11351130
width: Optional[int] = None,
11361131
padding_mask_crop: Optional[int] = None,
@@ -1158,6 +1153,7 @@ def __call__(
11581153
guess_mode: bool = False,
11591154
control_guidance_start: Union[float, List[float]] = 0.0,
11601155
control_guidance_end: Union[float, List[float]] = 1.0,
1156+
control_mode: Optional[Union[int, List[int]]] = None,
11611157
guidance_rescale: float = 0.0,
11621158
original_size: Tuple[int, int] = None,
11631159
crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1345,20 +1341,6 @@ def __call__(
13451341

13461342
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
13471343

1348-
if not isinstance(control_image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
1349-
raise ValueError(
1350-
"Expected type of `control_image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
1351-
)
1352-
if len(control_image_list) != controlnet.config.num_control_type:
1353-
if isinstance(control_image_list, ControlNetUnionInput):
1354-
raise ValueError(
1355-
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInputProMax`."
1356-
)
1357-
elif isinstance(control_image_list, ControlNetUnionInputProMax):
1358-
raise ValueError(
1359-
f"Expected num_control_type {controlnet.config.num_control_type}, got {len(control_image_list)}. Try `ControlNetUnionInput`."
1360-
)
1361-
13621344
# align format for control guidance
13631345
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
13641346
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -1375,36 +1357,44 @@ def __call__(
13751357
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
13761358
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
13771359

1360+
if not isinstance(control_image, list):
1361+
control_image = [control_image]
1362+
1363+
if not isinstance(control_mode, list):
1364+
control_mode = [control_mode]
1365+
1366+
if len(control_image) != len(control_mode):
1367+
raise ValueError("Expected len(control_image) == len(control_type)")
1368+
1369+
num_control_type = controlnet.config.num_control_type
1370+
13781371
# 1. Check inputs
1379-
control_type = []
1380-
for image_type in control_image_list:
1381-
if control_image_list[image_type]:
1382-
self.check_inputs(
1383-
prompt,
1384-
prompt_2,
1385-
control_image_list[image_type],
1386-
mask_image,
1387-
strength,
1388-
num_inference_steps,
1389-
callback_steps,
1390-
output_type,
1391-
negative_prompt,
1392-
negative_prompt_2,
1393-
prompt_embeds,
1394-
negative_prompt_embeds,
1395-
ip_adapter_image,
1396-
ip_adapter_image_embeds,
1397-
pooled_prompt_embeds,
1398-
negative_pooled_prompt_embeds,
1399-
controlnet_conditioning_scale,
1400-
control_guidance_start,
1401-
control_guidance_end,
1402-
callback_on_step_end_tensor_inputs,
1403-
padding_mask_crop,
1404-
)
1405-
control_type.append(1)
1406-
else:
1407-
control_type.append(0)
1372+
control_type = [0 for _ in range(num_control_type)]
1373+
for _image, control_idx in zip(control_image, control_mode):
1374+
control_type[control_idx] = 1
1375+
self.check_inputs(
1376+
prompt,
1377+
prompt_2,
1378+
_image,
1379+
mask_image,
1380+
strength,
1381+
num_inference_steps,
1382+
callback_steps,
1383+
output_type,
1384+
negative_prompt,
1385+
negative_prompt_2,
1386+
prompt_embeds,
1387+
negative_prompt_embeds,
1388+
ip_adapter_image,
1389+
ip_adapter_image_embeds,
1390+
pooled_prompt_embeds,
1391+
negative_pooled_prompt_embeds,
1392+
controlnet_conditioning_scale,
1393+
control_guidance_start,
1394+
control_guidance_end,
1395+
callback_on_step_end_tensor_inputs,
1396+
padding_mask_crop,
1397+
)
14081398

14091399
control_type = torch.Tensor(control_type)
14101400

@@ -1499,23 +1489,21 @@ def denoising_value_valid(dnv):
14991489
init_image = init_image.to(dtype=torch.float32)
15001490

15011491
# 5.2 Prepare control images
1502-
for image_type in control_image_list:
1503-
if control_image_list[image_type]:
1504-
control_image = self.prepare_control_image(
1505-
image=control_image_list[image_type],
1506-
width=width,
1507-
height=height,
1508-
batch_size=batch_size * num_images_per_prompt,
1509-
num_images_per_prompt=num_images_per_prompt,
1510-
device=device,
1511-
dtype=controlnet.dtype,
1512-
crops_coords=crops_coords,
1513-
resize_mode=resize_mode,
1514-
do_classifier_free_guidance=self.do_classifier_free_guidance,
1515-
guess_mode=guess_mode,
1516-
)
1517-
height, width = control_image.shape[-2:]
1518-
control_image_list[image_type] = control_image
1492+
for idx, _ in enumerate(control_image):
1493+
control_image[idx] = self.prepare_control_image(
1494+
image=control_image[idx],
1495+
width=width,
1496+
height=height,
1497+
batch_size=batch_size * num_images_per_prompt,
1498+
num_images_per_prompt=num_images_per_prompt,
1499+
device=device,
1500+
dtype=controlnet.dtype,
1501+
crops_coords=crops_coords,
1502+
resize_mode=resize_mode,
1503+
do_classifier_free_guidance=self.do_classifier_free_guidance,
1504+
guess_mode=guess_mode,
1505+
)
1506+
height, width = control_image[idx].shape[-2:]
15191507

15201508
# 5.3 Prepare mask
15211509
mask = self.mask_processor.preprocess(
@@ -1589,6 +1577,9 @@ def denoising_value_valid(dnv):
15891577

15901578
original_size = original_size or (height, width)
15911579
target_size = target_size or (height, width)
1580+
for _image in control_image:
1581+
if isinstance(_image, torch.Tensor):
1582+
original_size = original_size or _image.shape[-2:]
15921583

15931584
# 10. Prepare added time ids & embeddings
15941585
add_text_embeds = pooled_prompt_embeds
@@ -1693,8 +1684,9 @@ def denoising_value_valid(dnv):
16931684
control_model_input,
16941685
t,
16951686
encoder_hidden_states=controlnet_prompt_embeds,
1696-
controlnet_cond=control_image_list,
1687+
controlnet_cond=control_image,
16971688
control_type=control_type,
1689+
control_type_idx=control_mode,
16981690
conditioning_scale=cond_scale,
16991691
guess_mode=guess_mode,
17001692
added_cond_kwargs=controlnet_added_cond_kwargs,

0 commit comments

Comments
 (0)