Skip to content

Commit 42356ec

Browse files
committed
Add ControlNet support to denoise
1 parent f9c61f1 commit 42356ec

File tree

2 files changed

+213
-16
lines changed

2 files changed

+213
-16
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
5959
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
6060
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
61+
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
6162
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
6263
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
6364
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
@@ -463,6 +464,39 @@ def prep_control_data(
463464

464465
return controlnet_data
465466

467+
@staticmethod
def parse_controlnet_field(
    exit_stack: ExitStack,
    context: InvocationContext,
    control_input: ControlField | list[ControlField] | None,
    ext_manager: ExtensionsManager,
) -> None:
    """Register one ControlNetExt with the extensions manager per control input.

    Models are loaded through *exit_stack* so they stay resident for the
    lifetime of the denoise run and are released when the stack unwinds.
    """
    # Coerce the polymorphic field value into a uniform list of ControlFields.
    if control_input is None:
        controls: list[ControlField] = []
    elif isinstance(control_input, ControlField):
        controls = [control_input]
    elif isinstance(control_input, list):
        controls = control_input
    else:
        raise ValueError(f"Unexpected control_input type: {type(control_input)}")

    for control_info in controls:
        loaded_model = exit_stack.enter_context(context.models.load(control_info.control_model))
        ext_manager.add_extension(
            ControlNetExt(
                model=loaded_model,
                image=context.images.get_pil(control_info.image.image_name),
                weight=control_info.control_weight,
                begin_step_percent=control_info.begin_step_percent,
                end_step_percent=control_info.end_step_percent,
                control_mode=control_info.control_mode,
                resize_mode=control_info.resize_mode,
            )
        )
    # MultiControlNetModel has been refactored out, just need list[ControlNetData]
499+
466500
def prep_ip_adapter_image_prompts(
467501
self,
468502
context: InvocationContext,
@@ -790,22 +824,30 @@ def step_callback(state: PipelineIntermediateState) -> None:
790824

791825
ext_manager.add_extension(PreviewExt(step_callback))
792826

793-
# ext: t2i/ip adapter
794-
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
795-
796-
unet_info = context.models.load(self.unet.unet)
797-
assert isinstance(unet_info.model, UNet2DConditionModel)
798-
with (
799-
unet_info.model_on_device() as (model_state_dict, unet),
800-
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
801-
# ext: controlnet
802-
ext_manager.patch_extensions(unet),
803-
# ext: freeu, seamless, ip adapter, lora
804-
ext_manager.patch_unet(model_state_dict, unet),
805-
):
806-
sd_backend = StableDiffusionBackend(unet, scheduler)
807-
denoise_ctx.unet = unet
808-
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
827+
# context for loading additional models
828+
with ExitStack() as exit_stack:
829+
# later should be smth like:
830+
# for extension_field in self.extensions:
831+
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
832+
# ext_manager.add_extension(ext)
833+
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
834+
835+
# ext: t2i/ip adapter
836+
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
837+
838+
unet_info = context.models.load(self.unet.unet)
839+
assert isinstance(unet_info.model, UNet2DConditionModel)
840+
with (
841+
unet_info.model_on_device() as (model_state_dict, unet),
842+
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
843+
# ext: controlnet
844+
ext_manager.patch_extensions(denoise_ctx),
845+
# ext: freeu, seamless, ip adapter, lora
846+
ext_manager.patch_unet(model_state_dict, unet),
847+
):
848+
sd_backend = StableDiffusionBackend(unet, scheduler)
849+
denoise_ctx.unet = unet
850+
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
809851

810852
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
811853
result_latents = result_latents.detach().to("cpu")
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
from __future__ import annotations
2+
3+
import math
4+
from contextlib import contextmanager
5+
from typing import TYPE_CHECKING, List, Optional, Union
6+
7+
import torch
8+
from PIL.Image import Image
9+
10+
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
11+
from invokeai.app.util.controlnet_utils import prepare_control_image
12+
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
13+
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
14+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
15+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
16+
17+
if TYPE_CHECKING:
18+
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
19+
from invokeai.backend.util.hotfixes import ControlNetModel
20+
21+
22+
class ControlNetExt(ExtensionBase):
    """Denoise-pipeline extension that runs a diffusers ControlNet each step and
    injects its residuals into the UNet call via the shared ``UNetKwargs``."""

    def __init__(
        self,
        model: ControlNetModel,
        image: Image,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
        control_mode: str,
        resize_mode: str,
    ):
        """
        Args:
            model: Loaded diffusers ControlNet model.
            image: Control image (PIL); converted to a tensor at the latent-implied
                resolution when the denoise loop starts.
            weight: ControlNet conditioning scale — a constant, or one value per step.
            begin_step_percent: Fraction of total steps at which this ControlNet activates.
            end_step_percent: Fraction of total steps after which it deactivates.
            control_mode: Control mode string (e.g. "more_prompt", "more_control",
                "unbalanced"); mapped to soft/cfg injection flags in pre_unet_step.
            resize_mode: How the control image is fitted to the target resolution.
        """
        super().__init__()
        self.model = model
        self.image = image
        self.weight = weight
        self.begin_step_percent = begin_step_percent
        self.end_step_percent = end_step_percent
        self.control_mode = control_mode
        self.resize_mode = resize_mode

        # Populated by resize_image() in the PRE_DENOISE_LOOP callback; _run()
        # assumes it is set before any PRE_UNET callback fires.
        self.image_tensor: Optional[torch.Tensor] = None

    @contextmanager
    def patch_extension(self, ctx: DenoiseContext):
        """Temporarily swap the ControlNet's attention processors for the run's
        custom processor class, restoring the originals on exit."""
        # Capture the originals *before* entering the try block: if this lookup is
        # inside the try and raises, the finally clause would hit a NameError on
        # an undefined `original_processors`.
        original_processors = self.model.attn_processors
        try:
            self.model.set_attn_processor(ctx.inputs.attention_processor_cls())
            yield None
        finally:
            self.model.set_attn_processor(original_processors)

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def resize_image(self, ctx: DenoiseContext):
        """Prepare the control image tensor at the pixel resolution implied by the latents."""
        _, _, latent_height, latent_width = ctx.latents.shape
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        self.image_tensor = prepare_control_image(
            image=self.image,
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            # batch_size=batch_size * num_images_per_prompt,
            # num_images_per_prompt=num_images_per_prompt,
            device=ctx.latents.device,
            dtype=ctx.latents.dtype,
            control_mode=self.control_mode,
            resize_mode=self.resize_mode,
        )

    @callback(ExtensionCallbackType.PRE_UNET)
    def pre_unet_step(self, ctx: DenoiseContext):
        """Compute this ControlNet's residuals for the current step and attach
        (or accumulate) them onto ctx.unet_kwargs."""
        # skip if model not active in current step
        total_steps = len(ctx.inputs.timesteps)
        first_step = math.floor(self.begin_step_percent * total_steps)
        last_step = math.ceil(self.end_step_percent * total_steps)
        if ctx.step_index < first_step or ctx.step_index > last_step:
            return

        # convert mode to internal flags
        soft_injection = self.control_mode in ["more_prompt", "more_control"]
        cfg_injection = self.control_mode in ["more_control", "unbalanced"]

        # no negative conditioning in cfg_injection mode
        if cfg_injection:
            if ctx.conditioning_mode == ConditioningMode.Negative:
                return
            down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)

            if ctx.conditioning_mode == ConditioningMode.Both:
                # add zeros as samples for negative conditioning
                down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
                mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

        else:
            down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)

        if (
            ctx.unet_kwargs.down_block_additional_residuals is None
            and ctx.unet_kwargs.mid_block_additional_residual is None
        ):
            ctx.unet_kwargs.down_block_additional_residuals = down_samples
            ctx.unet_kwargs.mid_block_additional_residual = mid_sample
        else:
            # add controlnet outputs together if have multiple controlnets
            ctx.unet_kwargs.down_block_additional_residuals = [
                samples_prev + samples_curr
                for samples_prev, samples_curr in zip(
                    ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
                )
            ]
            ctx.unet_kwargs.mid_block_additional_residual += mid_sample

    def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
        """Run one ControlNet forward pass; returns (down_samples, mid_sample) residuals."""
        total_steps = len(ctx.inputs.timesteps)

        model_input = ctx.latent_model_input
        image_tensor = self.image_tensor
        if conditioning_mode == ConditioningMode.Both:
            # Duplicate inputs so positive and negative conditioning run in one batch.
            model_input = torch.cat([model_input] * 2)
            image_tensor = torch.cat([image_tensor] * 2)

        cn_unet_kwargs = UNetKwargs(
            sample=model_input,
            timestep=ctx.timestep,
            encoder_hidden_states=None,  # set later by conditioning
            cross_attention_kwargs=dict(  # noqa: C408
                percent_through=ctx.step_index / total_steps,
            ),
        )

        ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)

        # get static weight, or weight corresponding to current step
        weight = self.weight
        if isinstance(weight, list):
            weight = weight[ctx.step_index]

        # vars() returns the live __dict__ of the UNetKwargs instance, so pop the
        # UNet-only residual fields (which ControlNetModel.forward does not accept)
        # and pass the resulting dict explicitly rather than re-reading vars().
        cn_kwargs = vars(cn_unet_kwargs)
        cn_kwargs.pop("down_block_additional_residuals", None)
        cn_kwargs.pop("mid_block_additional_residual", None)
        cn_kwargs.pop("down_intrablock_additional_residuals", None)

        # controlnet(s) inference
        down_samples, mid_sample = self.model(
            controlnet_cond=image_tensor,
            conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
            guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
            return_dict=False,
            **cn_kwargs,
        )

        return down_samples, mid_sample

0 commit comments

Comments
 (0)