
Commit 6cc3323

Address changes
1 parent 3ecb041 commit 6cc3323

File tree

2 files changed: 37 additions & 120 deletions


src/diffusers/models/controlnet_union.py

Lines changed: 3 additions & 27 deletions
@@ -30,7 +30,7 @@
     AttnProcessor,
 )
 from .controlnet import ControlNetConditioningEmbedding, ControlNetOutput, zero_module
-from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from .embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from .modeling_utils import ModelMixin
 from .unets.unet_2d_blocks import (
     CrossAttnDownBlock2D,
@@ -282,32 +282,8 @@ def __init__(
             act_fn=act_fn,
         )

-        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
-            encoder_hid_dim_type = "text_proj"
-            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
-            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
-
-        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
-            raise ValueError(
-                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
-            )
-
-        if encoder_hid_dim_type == "text_proj":
-            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
-        elif encoder_hid_dim_type == "text_image_proj":
-            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
-            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
-            self.encoder_hid_proj = TextImageProjection(
-                text_embed_dim=encoder_hid_dim,
-                image_embed_dim=cross_attention_dim,
-                cross_attention_dim=cross_attention_dim,
-            )
-        elif encoder_hid_dim_type is not None:
-            raise ValueError(
-                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
-            )
+        if encoder_hid_dim_type is not None:
+            raise ValueError(f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None.")
         else:
             self.encoder_hid_proj = None
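
After this hunk, the union ControlNet no longer builds any `encoder_hid_proj`; the constructor simply rejects a non-None `encoder_hid_dim_type`. A minimal sketch of the resulting guard, restated outside the class (the helper name is hypothetical; the real check lives in `ControlNetUnionModel.__init__`):

# Hypothetical, standalone restatement of the post-commit guard.
def check_encoder_hid_dim_type(encoder_hid_dim_type=None):
    # Any non-None value is now rejected instead of selecting a projection layer.
    if encoder_hid_dim_type is not None:
        raise ValueError(f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None.")
    return None  # i.e. self.encoder_hid_proj stays None


check_encoder_hid_dim_type(None)           # fine
# check_encoder_hid_dim_type("text_proj")  # raises ValueError after this commit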

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 34 additions & 93 deletions
@@ -16,8 +16,6 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

-import numpy as np
-import PIL.Image
 import torch
 import torch.nn.functional as F
 from transformers import (
@@ -48,7 +46,6 @@
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
-    deprecate,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -615,8 +612,7 @@ def check_inputs(
         self,
         prompt,
         prompt_2,
-        image,
-        callback_steps,
+        image: PipelineImageInput,
         negative_prompt=None,
         negative_prompt_2=None,
         prompt_embeds=None,
@@ -630,12 +626,6 @@ def check_inputs(
         control_guidance_end=1.0,
         callback_on_step_end_tensor_inputs=None,
     ):
-        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
-            raise ValueError(
-                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
-                f" {type(callback_steps)}."
-            )
-
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
         ):
@@ -767,43 +757,25 @@ def check_inputs(
                 f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
             )

-    # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
-    def check_image(self, image, prompt, prompt_embeds):
-        image_is_pil = isinstance(image, PIL.Image.Image)
-        image_is_tensor = isinstance(image, torch.Tensor)
-        image_is_np = isinstance(image, np.ndarray)
-        image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
-        image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
-        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
-
-        if (
-            not image_is_pil
-            and not image_is_tensor
-            and not image_is_np
-            and not image_is_pil_list
-            and not image_is_tensor_list
-            and not image_is_np_list
-        ):
-            raise TypeError(
-                f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
-            )
-
-        if image_is_pil:
-            image_batch_size = 1
-        else:
-            image_batch_size = len(image)
-
-        if prompt is not None and isinstance(prompt, str):
-            prompt_batch_size = 1
-        elif prompt is not None and isinstance(prompt, list):
-            prompt_batch_size = len(prompt)
-        elif prompt_embeds is not None:
-            prompt_batch_size = prompt_embeds.shape[0]
+    def check_input(
+        self,
+        image: Union[ControlNetUnionInput, ControlNetUnionInputProMax],
+    ):
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

-        if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+        if not isinstance(image, (ControlNetUnionInput, ControlNetUnionInputProMax)):
             raise ValueError(
-                f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+                "Expected type of `image` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
             )
+        if len(image) != controlnet.config.num_control_type:
+            if isinstance(image, ControlNetUnionInput):
+                raise ValueError(
+                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInputProMax`."
+                )
+            elif isinstance(image, ControlNetUnionInputProMax):
+                raise ValueError(
+                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image)}. Try `ControlNetUnionInput`."
+                )

     # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
     def prepare_image(
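
The new `check_input` replaces the copied generic `check_image` with two checks: the container type, and its length against `controlnet.config.num_control_type`. A self-contained restatement of the length check (hypothetical helper and placeholder slot names; the real containers define the canonical control types):

# Hypothetical restatement of the length check in check_input.
def validate_union_input(image: dict, num_control_type: int) -> None:
    # `image` stands in for ControlNetUnionInput / ControlNetUnionInputProMax,
    # both of which are iterated and indexed by control type elsewhere in this diff.
    if len(image) != num_control_type:
        raise ValueError(
            f"Expected num_control_type {num_control_type}, got {len(image)}. "
            "Try the other input container."
        )


slots = {name: None for name in ("openpose", "depth", "hed", "canny", "normal", "segment")}
validate_union_input(slots, num_control_type=6)    # passes
# validate_union_input(slots, num_control_type=8)  # raises, as for a larger ProMax-sized model
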
@@ -823,9 +795,11 @@ def prepare_image(

         if image_batch_size == 1:
             repeat_by = batch_size
-        else:
+        elif image_batch_size == batch_size:
             # image batch size is the same as prompt batch size
             repeat_by = num_images_per_prompt
+        else:
+            raise ValueError(f"Expected image batch size == 1 or `batch_size`, got {image_batch_size}.")

         image = image.repeat_interleave(repeat_by, dim=0)
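
The `prepare_image` change turns the old catch-all `else` into an explicit error: a conditioning image batch must be either 1 or exactly the batch size handed in. A small sketch of that rule in isolation (hypothetical helper name; `batch_size` means whatever `__call__` passes down):

# Hypothetical restatement of the repeat rule in prepare_image.
def repeat_factor(image_batch_size: int, batch_size: int, num_images_per_prompt: int) -> int:
    if image_batch_size == 1:
        # a single conditioning image is repeated for the whole batch
        return batch_size
    elif image_batch_size == batch_size:
        # image batch size is the same as prompt batch size
        return num_images_per_prompt
    # anything else is now rejected instead of silently repeated
    raise ValueError(f"Expected image batch size == 1 or `batch_size`, got {image_batch_size}.")


assert repeat_factor(1, batch_size=2, num_images_per_prompt=1) == 2
assert repeat_factor(2, batch_size=2, num_images_per_prompt=1) == 1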

@@ -964,7 +938,7 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image_list: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
+        image: Union[ControlNetUnionInput, ControlNetUnionInputProMax] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
@@ -1002,7 +976,6 @@ def __call__(
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        **kwargs,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -1013,7 +986,7 @@ def __call__(
             prompt_2 (`str` or `List[str]`, *optional*):
                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
                 used in both text-encoders.
-            image_list (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
+            image (`Union[ControlNetUnionInput, ControlNetUnionInputProMax]`):
                 In turn this supports (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
                 `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]`
                 or `List[List[PIL.Image.Image]]`):
@@ -1158,40 +1131,12 @@ def __call__(
                 otherwise a `tuple` is returned containing the output images.
         """

-        callback = kwargs.pop("callback", None)
-        callback_steps = kwargs.pop("callback_steps", None)
-
-        if callback is not None:
-            deprecate(
-                "callback",
-                "1.0.0",
-                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
-            )
-        if callback_steps is not None:
-            deprecate(
-                "callback_steps",
-                "1.0.0",
-                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
-            )
-
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet

-        if not isinstance(image_list, (ControlNetUnionInput, ControlNetUnionInputProMax)):
-            raise ValueError(
-                "Expected type of `image_list` to be one of `ControlNetUnionInput` or `ControlNetUnionInputProMax`"
-            )
-        if len(image_list) != controlnet.config.num_control_type:
-            if isinstance(image_list, ControlNetUnionInput):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image_list)}. Try `ControlNetUnionInputProMax`."
-                )
-            elif isinstance(image_list, ControlNetUnionInputProMax):
-                raise ValueError(
-                    f"Expected num_control_type {controlnet.config.num_control_type}, got {len(image_list)}. Try `ControlNetUnionInput`."
-                )
+        self.check_input(image)

         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
@@ -1201,13 +1146,12 @@ def __call__(

         # 1. Check inputs. Raise error if not correct
         control_type = []
-        for image_type in image_list:
-            if image_list[image_type]:
+        for image_type in image:
+            if image[image_type]:
                 self.check_inputs(
                     prompt,
                     prompt_2,
-                    image_list[image_type],
-                    callback_steps,
+                    image[image_type],
                     negative_prompt,
                     negative_prompt_2,
                     prompt_embeds,
@@ -1282,10 +1226,10 @@ def __call__(
             )

         # 4. Prepare image
-        for image_type in image_list:
-            if image_list[image_type]:
+        for image_type in image:
+            if image[image_type]:
                 image = self.prepare_image(
-                    image=image_list[image_type],
+                    image=image[image_type],
                     width=width,
                     height=height,
                     batch_size=batch_size * num_images_per_prompt,
@@ -1296,7 +1240,7 @@ def __call__(
                     guess_mode=guess_mode,
                 )
                 height, width = image.shape[-2:]
-                image_list[image_type] = image
+                image[image_type] = image

         # 5. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
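
The same per-slot iteration pattern recurs through `__call__` (input checking in step 1, image preparation in step 4, size inference in step 7.2): walk the container, skip unset slots, and transform the set ones in place. A toy, self-contained version of that pattern, with a plain dict standing in for the union input types:

# Illustrative pattern only; in the pipeline the container is a
# ControlNetUnionInput / ControlNetUnionInputProMax and the per-slot work is
# check_inputs() / prepare_image() rather than this toy transform.
def for_each_set_slot(container: dict, fn):
    for image_type in container:       # iterate control-type keys
        if container[image_type]:      # skip slots left unset / None
            container[image_type] = fn(container[image_type])
    return container


slots = {"openpose": "pose.png", "depth": None, "canny": None}
print(for_each_set_slot(slots, str.upper))
# {'openpose': 'POSE.PNG', 'depth': None, 'canny': None}
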
@@ -1337,9 +1281,9 @@ def __call__(
             )

         # 7.2 Prepare added time ids & embeddings
-        for image_type in image_list:
-            if isinstance(image_list[image_type], torch.Tensor):
-                original_size = original_size or image_list[image_type].shape[-2:]
+        for image_type in image:
+            if isinstance(image[image_type], torch.Tensor):
+                original_size = original_size or image[image_type].shape[-2:]

         target_size = target_size or (height, width)
         add_text_embeds = pooled_prompt_embeds
@@ -1449,7 +1393,7 @@ def __call__(
                     control_model_input,
                     t,
                     encoder_hidden_states=controlnet_prompt_embeds,
-                    controlnet_cond=image_list,
+                    controlnet_cond=image,
                     control_type=control_type,
                     conditioning_scale=cond_scale,
                     guess_mode=guess_mode,
@@ -1508,9 +1452,6 @@ def __call__(
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
-                    if callback is not None and i % callback_steps == 0:
-                        step_idx = i // getattr(self.scheduler, "order", 1)
-                        callback(step_idx, t, latents)

         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
