Skip to content

Commit c5cb0b1

Browse files
committed
remove MultiControlNetModel
1 parent 713bf6b commit c5cb0b1

File tree

3 files changed

+1
-195
lines changed

3 files changed

+1
-195
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 1 addition & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
from ...utils.torch_utils import is_compiled_module, randn_tensor
5656
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
5757
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
58-
from .multicontrolnet import MultiControlNetModel
5958

6059

6160
if is_invisible_watermark_available():
@@ -252,9 +251,6 @@ def __init__(
252251
if not isinstance(controlnet, ControlNetUnionModel):
253252
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
254253

255-
if isinstance(controlnet, (list, tuple)):
256-
controlnet = MultiControlNetModel(controlnet)
257-
258254
self.register_modules(
259255
vae=vae,
260256
text_encoder=text_encoder,
@@ -754,15 +750,6 @@ def check_inputs(
754750
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
755751
)
756752

757-
# `prompt` needs more sophisticated handling when there are multiple
758-
# conditionings.
759-
if isinstance(self.controlnet, MultiControlNetModel):
760-
if isinstance(prompt, list):
761-
logger.warning(
762-
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
763-
" prompts. The conditionings will be fixed across the prompts."
764-
)
765-
766753
# Check `image`
767754
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
768755
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -779,25 +766,7 @@ def check_inputs(
779766
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
780767
):
781768
self.check_image(image, prompt, prompt_embeds)
782-
elif (
783-
isinstance(self.controlnet, MultiControlNetModel)
784-
or is_compiled
785-
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
786-
):
787-
if not isinstance(image, list):
788-
raise TypeError("For multiple controlnets: `image` must be type `list`")
789-
790-
# When `image` is a nested list:
791-
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
792-
elif any(isinstance(i, list) for i in image):
793-
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
794-
elif len(image) != len(self.controlnet.nets):
795-
raise ValueError(
796-
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
797-
)
798769

799-
for image_ in image:
800-
self.check_image(image_, prompt, prompt_embeds)
801770
else:
802771
assert False
803772

@@ -818,21 +787,6 @@ def check_inputs(
818787
if not isinstance(controlnet_conditioning_scale, float):
819788
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
820789

821-
elif (
822-
isinstance(self.controlnet, MultiControlNetModel)
823-
or is_compiled
824-
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
825-
):
826-
if isinstance(controlnet_conditioning_scale, list):
827-
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
828-
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
829-
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
830-
self.controlnet.nets
831-
):
832-
raise ValueError(
833-
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
834-
" the same length as the number of controlnets"
835-
)
836790
else:
837791
assert False
838792

@@ -847,12 +801,6 @@ def check_inputs(
847801
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
848802
)
849803

850-
if isinstance(self.controlnet, MultiControlNetModel):
851-
if len(control_guidance_start) != len(self.controlnet.nets):
852-
raise ValueError(
853-
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
854-
)
855-
856804
for start, end in zip(control_guidance_start, control_guidance_end):
857805
if start >= end:
858806
raise ValueError(
@@ -1422,12 +1370,6 @@ def __call__(
14221370
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
14231371
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
14241372
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1425-
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1426-
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1427-
control_guidance_start, control_guidance_end = (
1428-
mult * [control_guidance_start],
1429-
mult * [control_guidance_end],
1430-
)
14311373

14321374
# # 0.0 Default height and width to unet
14331375
# height = height or self.unet.config.sample_size * self.vae_scale_factor
@@ -1438,12 +1380,6 @@ def __call__(
14381380
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
14391381
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
14401382
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1441-
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1442-
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1443-
control_guidance_start, control_guidance_end = (
1444-
mult * [control_guidance_start],
1445-
mult * [control_guidance_end],
1446-
)
14471383

14481384
# 1. Check inputs
14491385
control_type = []
@@ -1493,9 +1429,6 @@ def __call__(
14931429

14941430
device = self._execution_device
14951431

1496-
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1497-
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1498-
14991432
# 3. Encode input prompt
15001433
text_encoder_lora_scale = (
15011434
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
@@ -1666,7 +1599,7 @@ def denoising_value_valid(dnv):
16661599
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
16671600
for s, e in zip(control_guidance_start, control_guidance_end)
16681601
]
1669-
controlnet_keep.append(keeps if isinstance(controlnet, MultiControlNetModel) else keeps[0])
1602+
controlnet_keep.append(keeps[0])
16701603

16711604
# 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
16721605
height, width = latents.shape[-2:]

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@
6262
if is_invisible_watermark_available():
6363
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
6464

65-
from .multicontrolnet import MultiControlNetModel
66-
67-
6865
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6966

7067

@@ -264,9 +261,6 @@ def __init__(
264261
if not isinstance(controlnet, ControlNetUnionModel):
265262
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
266263

267-
if isinstance(controlnet, (list, tuple)):
268-
controlnet = MultiControlNetModel(controlnet)
269-
270264
self.register_modules(
271265
vae=vae,
272266
text_encoder=text_encoder,
@@ -697,15 +691,6 @@ def check_inputs(
697691
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
698692
)
699693

700-
# `prompt` needs more sophisticated handling when there are multiple
701-
# conditionings.
702-
if isinstance(self.controlnet, MultiControlNetModel):
703-
if isinstance(prompt, list):
704-
logger.warning(
705-
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
706-
" prompts. The conditionings will be fixed across the prompts."
707-
)
708-
709694
# Check `image`
710695
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
711696
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -722,25 +707,7 @@ def check_inputs(
722707
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
723708
):
724709
self.check_image(image, prompt, prompt_embeds)
725-
elif (
726-
isinstance(self.controlnet, MultiControlNetModel)
727-
or is_compiled
728-
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
729-
):
730-
if not isinstance(image, list):
731-
raise TypeError("For multiple controlnets: `image` must be type `list`")
732-
733-
# When `image` is a nested list:
734-
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
735-
elif any(isinstance(i, list) for i in image):
736-
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
737-
elif len(image) != len(self.controlnet.nets):
738-
raise ValueError(
739-
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
740-
)
741710

742-
for image_ in image:
743-
self.check_image(image_, prompt, prompt_embeds)
744711
else:
745712
assert False
746713

@@ -761,21 +728,6 @@ def check_inputs(
761728
if not isinstance(controlnet_conditioning_scale, float):
762729
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
763730

764-
elif (
765-
isinstance(self.controlnet, MultiControlNetModel)
766-
or is_compiled
767-
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
768-
):
769-
if isinstance(controlnet_conditioning_scale, list):
770-
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
771-
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
772-
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
773-
self.controlnet.nets
774-
):
775-
raise ValueError(
776-
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
777-
" the same length as the number of controlnets"
778-
)
779731
else:
780732
assert False
781733

@@ -790,12 +742,6 @@ def check_inputs(
790742
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
791743
)
792744

793-
if isinstance(self.controlnet, MultiControlNetModel):
794-
if len(control_guidance_start) != len(self.controlnet.nets):
795-
raise ValueError(
796-
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
797-
)
798-
799745
for start, end in zip(control_guidance_start, control_guidance_end):
800746
if start >= end:
801747
raise ValueError(
@@ -1252,12 +1198,6 @@ def __call__(
12521198
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
12531199
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
12541200
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1255-
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1256-
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1257-
control_guidance_start, control_guidance_end = (
1258-
mult * [control_guidance_start],
1259-
mult * [control_guidance_end],
1260-
)
12611201

12621202
# 1. Check inputs. Raise error if not correct
12631203
control_type = []
@@ -1303,9 +1243,6 @@ def __call__(
13031243

13041244
device = self._execution_device
13051245

1306-
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1307-
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1308-
13091246
global_pool_conditions = (
13101247
controlnet.config.global_pool_conditions
13111248
if isinstance(controlnet, ControlNetModel)

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@
6262
if is_invisible_watermark_available():
6363
from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
6464

65-
from .multicontrolnet import MultiControlNetModel
66-
67-
6865
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6966

7067

@@ -256,9 +253,6 @@ def __init__(
256253
if not isinstance(controlnet, ControlNetUnionModel):
257254
raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
258255

259-
if isinstance(controlnet, (list, tuple)):
260-
controlnet = MultiControlNetModel(controlnet)
261-
262256
self.register_modules(
263257
vae=vae,
264258
text_encoder=text_encoder,
@@ -702,15 +696,6 @@ def check_inputs(
702696
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
703697
)
704698

705-
# `prompt` needs more sophisticated handling when there are multiple
706-
# conditionings.
707-
if isinstance(self.controlnet, MultiControlNetModel):
708-
if isinstance(prompt, list):
709-
logger.warning(
710-
f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
711-
" prompts. The conditionings will be fixed across the prompts."
712-
)
713-
714699
# Check `image`
715700
is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
716701
self.controlnet, torch._dynamo.eval_frame.OptimizedModule
@@ -727,25 +712,6 @@ def check_inputs(
727712
and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
728713
):
729714
self.check_image(image, prompt, prompt_embeds)
730-
elif (
731-
isinstance(self.controlnet, MultiControlNetModel)
732-
or is_compiled
733-
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
734-
):
735-
if not isinstance(image, list):
736-
raise TypeError("For multiple controlnets: `image` must be type `list`")
737-
738-
# When `image` is a nested list:
739-
# (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
740-
elif any(isinstance(i, list) for i in image):
741-
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
742-
elif len(image) != len(self.controlnet.nets):
743-
raise ValueError(
744-
f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
745-
)
746-
747-
for image_ in image:
748-
self.check_image(image_, prompt, prompt_embeds)
749715
else:
750716
assert False
751717

@@ -766,21 +732,6 @@ def check_inputs(
766732
if not isinstance(controlnet_conditioning_scale, float):
767733
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
768734

769-
elif (
770-
isinstance(self.controlnet, MultiControlNetModel)
771-
or is_compiled
772-
and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
773-
):
774-
if isinstance(controlnet_conditioning_scale, list):
775-
if any(isinstance(i, list) for i in controlnet_conditioning_scale):
776-
raise ValueError("A single batch of multiple conditionings are supported at the moment.")
777-
elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
778-
self.controlnet.nets
779-
):
780-
raise ValueError(
781-
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
782-
" the same length as the number of controlnets"
783-
)
784735
else:
785736
assert False
786737

@@ -795,12 +746,6 @@ def check_inputs(
795746
f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
796747
)
797748

798-
if isinstance(self.controlnet, MultiControlNetModel):
799-
if len(control_guidance_start) != len(self.controlnet.nets):
800-
raise ValueError(
801-
f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
802-
)
803-
804749
for start, end in zip(control_guidance_start, control_guidance_end):
805750
if start >= end:
806751
raise ValueError(
@@ -1343,12 +1288,6 @@ def __call__(
13431288
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
13441289
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
13451290
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1346-
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1347-
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
1348-
control_guidance_start, control_guidance_end = (
1349-
mult * [control_guidance_start],
1350-
mult * [control_guidance_end],
1351-
)
13521291

13531292
# 1. Check inputs. Raise error if not correct
13541293
control_type = []
@@ -1395,9 +1334,6 @@ def __call__(
13951334

13961335
device = self._execution_device
13971336

1398-
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1399-
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1400-
14011337
global_pool_conditions = (
14021338
controlnet.config.global_pool_conditions
14031339
if isinstance(controlnet, (ControlNetModel, ControlNetUnionModel))

0 commit comments

Comments
 (0)