Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 58 additions & 16 deletions invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
Expand Down Expand Up @@ -463,6 +464,39 @@ def prep_control_data(

return controlnet_data

@staticmethod
def parse_controlnet_field(
exit_stack: ExitStack,
context: InvocationContext,
control_input: ControlField | list[ControlField] | None,
ext_manager: ExtensionsManager,
) -> None:
# Normalize control_input to a list.
control_list: list[ControlField]
if isinstance(control_input, ControlField):
control_list = [control_input]
elif isinstance(control_input, list):
control_list = control_input
elif control_input is None:
control_list = []
else:
raise ValueError(f"Unexpected control_input type: {type(control_input)}")

for control_info in control_list:
model = exit_stack.enter_context(context.models.load(control_info.control_model))
ext_manager.add_extension(
ControlNetExt(
model=model,
image=context.images.get_pil(control_info.image.image_name),
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]

def prep_ip_adapter_image_prompts(
self,
context: InvocationContext,
Expand Down Expand Up @@ -790,22 +824,30 @@ def step_callback(state: PipelineIntermediateState) -> None:

ext_manager.add_extension(PreviewExt(step_callback))

# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(unet),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
# context for loading additional models
with ExitStack() as exit_stack:
# later should be smth like:
# for extension_field in self.extensions:
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
# ext_manager.add_extension(ext)
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)

# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

unet_info = context.models.load(self.unet.unet)
assert isinstance(unet_info.model, UNet2DConditionModel)
with (
unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
# ext: controlnet
ext_manager.patch_extensions(denoise_ctx),
# ext: freeu, seamless, ip adapter, lora
ext_manager.patch_unet(model_state_dict, unet),
):
sd_backend = StableDiffusionBackend(unet, scheduler)
denoise_ctx.unet = unet
result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.detach().to("cpu")
Expand Down
155 changes: 155 additions & 0 deletions invokeai/backend/stable_diffusion/extensions/controlnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from __future__ import annotations

import math
from contextlib import contextmanager
from typing import TYPE_CHECKING, List, Optional, Union

import torch
from PIL.Image import Image

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.util.hotfixes import ControlNetModel


class ControlNetExt(ExtensionBase):
def __init__(
self,
model: ControlNetModel,
image: Image,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
control_mode: str,
resize_mode: str,
):
super().__init__()
self.model = model
self.image = image
self.weight = weight
self.begin_step_percent = begin_step_percent
self.end_step_percent = end_step_percent
self.control_mode = control_mode
self.resize_mode = resize_mode

self.image_tensor: Optional[torch.Tensor] = None

@contextmanager
def patch_extension(self, ctx: DenoiseContext):
try:
original_processors = self.model.attn_processors
self.model.set_attn_processor(ctx.inputs.attention_processor_cls())

yield None
finally:
self.model.set_attn_processor(original_processors)

@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def resize_image(self, ctx: DenoiseContext):
_, _, latent_height, latent_width = ctx.latents.shape
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR

self.image_tensor = prepare_control_image(
image=self.image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=ctx.latents.device,
dtype=ctx.latents.dtype,
control_mode=self.control_mode,
resize_mode=self.resize_mode,
)

@callback(ExtensionCallbackType.PRE_UNET)
def pre_unet_step(self, ctx: DenoiseContext):
# skip if model not active in current step
total_steps = len(ctx.inputs.timesteps)
first_step = math.floor(self.begin_step_percent * total_steps)
last_step = math.ceil(self.end_step_percent * total_steps)
if ctx.step_index < first_step or ctx.step_index > last_step:
return

# convert mode to internal flags
soft_injection = self.control_mode in ["more_prompt", "more_control"]
cfg_injection = self.control_mode in ["more_control", "unbalanced"]

# no negative conditioning in cfg_injection mode
if cfg_injection:
if ctx.conditioning_mode == ConditioningMode.Negative:
return
down_samples, mid_sample = self._run(ctx, soft_injection, ConditioningMode.Positive)

if ctx.conditioning_mode == ConditioningMode.Both:
# add zeros as samples for negative conditioning
down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples]
mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample])

else:
down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode)

if (
ctx.unet_kwargs.down_block_additional_residuals is None
and ctx.unet_kwargs.mid_block_additional_residual is None
):
ctx.unet_kwargs.down_block_additional_residuals = down_samples
ctx.unet_kwargs.mid_block_additional_residual = mid_sample
else:
# add controlnet outputs together if have multiple controlnets
ctx.unet_kwargs.down_block_additional_residuals = [
samples_prev + samples_curr
for samples_prev, samples_curr in zip(
ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
)
]
ctx.unet_kwargs.mid_block_additional_residual += mid_sample

def _run(self, ctx: DenoiseContext, soft_injection: bool, conditioning_mode: ConditioningMode):
total_steps = len(ctx.inputs.timesteps)

model_input = ctx.latent_model_input
image_tensor = self.image_tensor
if conditioning_mode == ConditioningMode.Both:
model_input = torch.cat([model_input] * 2)
image_tensor = torch.cat([image_tensor] * 2)

cn_unet_kwargs = UNetKwargs(
sample=model_input,
timestep=ctx.timestep,
encoder_hidden_states=None, # set later by conditoning
cross_attention_kwargs=dict( # noqa: C408
percent_through=ctx.step_index / total_steps,
),
)

ctx.inputs.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)

# get static weight, or weight corresponding to current step
weight = self.weight
if isinstance(weight, list):
weight = weight[ctx.step_index]

tmp_kwargs = vars(cn_unet_kwargs)
tmp_kwargs.pop("down_block_additional_residuals", None)
tmp_kwargs.pop("mid_block_additional_residual", None)
tmp_kwargs.pop("down_intrablock_additional_residuals", None)

# controlnet(s) inference
down_samples, mid_sample = self.model(
controlnet_cond=image_tensor,
conditioning_scale=weight, # controlnet specific, NOT the guidance scale
guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel
return_dict=False,
**vars(cn_unet_kwargs),
)

return down_samples, mid_sample