
Commit 6af659b

Handle t2i adapter in modular denoise
1 parent 7c975f0 commit 6af659b

2 files changed: +144 −0 lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 29 additions & 0 deletions
@@ -62,6 +62,7 @@
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
+from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -498,6 +499,33 @@ def parse_controlnet_field(
                 )
             )
 
+    @staticmethod
+    def parse_t2i_adapter_field(
+        exit_stack: ExitStack,
+        context: InvocationContext,
+        t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        if t2i_adapters is None:
+            return
+
+        # Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
+        if isinstance(t2i_adapters, T2IAdapterField):
+            t2i_adapters = [t2i_adapters]
+
+        for t2i_adapter_field in t2i_adapters:
+            ext_manager.add_extension(
+                T2IAdapterExt(
+                    node_context=context,
+                    model_id=t2i_adapter_field.t2i_adapter_model,
+                    image=context.images.get_pil(t2i_adapter_field.image.image_name),
+                    weight=t2i_adapter_field.weight,
+                    begin_step_percent=t2i_adapter_field.begin_step_percent,
+                    end_step_percent=t2i_adapter_field.end_step_percent,
+                    resize_mode=t2i_adapter_field.resize_mode,
+                )
+            )
+
     def prep_ip_adapter_image_prompts(
         self,
         context: InvocationContext,
@@ -840,6 +868,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
             # ext = extension_field.to_extension(exit_stack, context, ext_manager)
             # ext_manager.add_extension(ext)
             self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+            self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
 
             # ext: t2i/ip adapter
             ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
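
The new parse_t2i_adapter_field relies on the extension/callback pattern used throughout the modular denoise code: each T2IAdapterExt registers methods decorated with @callback, and the ExtensionsManager later fires them at the matching point in the denoise loop (SETUP, PRE_UNET). Below is a minimal sketch of that dispatch pattern for orientation only; it is not InvokeAI's actual ExtensionsManager, and the class bodies here (SimpleExtensionsManager, the stand-in callback decorator and ExtensionCallbackType) are assumptions.

# Illustrative only: a simplified dispatcher in the spirit of add_extension + run_callback.
from enum import Enum
from typing import Callable, Dict, List


class ExtensionCallbackType(Enum):
    SETUP = "setup"
    PRE_UNET = "pre_unet"


def callback(kind: ExtensionCallbackType):
    """Mark a method as a handler for one callback type."""
    def decorator(fn: Callable):
        fn._callback_type = kind
        return fn
    return decorator


class SimpleExtensionsManager:
    def __init__(self):
        self._handlers: Dict[ExtensionCallbackType, List[Callable]] = {}

    def add_extension(self, ext: object) -> None:
        # Collect every bound method tagged by the @callback decorator.
        for name in dir(ext):
            fn = getattr(ext, name)
            kind = getattr(fn, "_callback_type", None)
            if kind is not None and callable(fn):
                self._handlers.setdefault(kind, []).append(fn)

    def run_callback(self, kind: ExtensionCallbackType, ctx) -> None:
        # Fire every registered handler for this point in the denoise loop.
        for fn in self._handlers.get(kind, []):
            fn(ctx)

With a manager of this shape, registering a T2IAdapterExt makes its setup() run at SETUP time and its pre_unet_step() run before each UNet call, which is exactly what the two hunks above wire up.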
invokeai/backend/stable_diffusion/extensions/t2i_adapter.py

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+from __future__ import annotations
+
+import math
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import torch
+from diffusers import T2IAdapter
+from PIL.Image import Image
+
+from invokeai.app.util.controlnet_utils import prepare_control_image
+from invokeai.backend.model_manager import BaseModelType
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
+from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+
+if TYPE_CHECKING:
+    from invokeai.app.invocations.model import ModelIdentifierField
+    from invokeai.app.services.shared.invocation_context import InvocationContext
+    from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
+    from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+
+
+class T2IAdapterExt(ExtensionBase):
+    def __init__(
+        self,
+        node_context: InvocationContext,
+        model_id: ModelIdentifierField,
+        image: Image,
+        weight: Union[float, List[float]],
+        begin_step_percent: float,
+        end_step_percent: float,
+        resize_mode: CONTROLNET_RESIZE_VALUES,
+    ):
+        super().__init__()
+        self._node_context = node_context
+        self._model_id = model_id
+        self._image = image
+        self._weight = weight
+        self._resize_mode = resize_mode
+        self._begin_step_percent = begin_step_percent
+        self._end_step_percent = end_step_percent
+
+        self._adapter_state: Optional[List[torch.Tensor]] = None
+
+        # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
+        model_config = self._node_context.models.get_config(self._model_id.key)
+        if model_config.base == BaseModelType.StableDiffusion1:
+            self._max_unet_downscale = 8
+        elif model_config.base == BaseModelType.StableDiffusionXL:
+            self._max_unet_downscale = 4
+        else:
+            raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")
+
+    @callback(ExtensionCallbackType.SETUP)
+    def setup(self, ctx: DenoiseContext):
+        t2i_model: T2IAdapter
+        with self._node_context.models.load(self._model_id) as t2i_model:
+            _, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
+
+            self._adapter_state = self._run_model(
+                model=t2i_model,
+                image=self._image,
+                latents_height=latents_height,
+                latents_width=latents_width,
+                max_unet_downscale=self._max_unet_downscale,
+                resize_mode=self._resize_mode,
+            )
+
+    def _run_model(
+        self,
+        model: T2IAdapter,
+        image: Image,
+        latents_height: int,
+        latents_width: int,
+        max_unet_downscale: int,
+        resize_mode: CONTROLNET_RESIZE_VALUES,
+    ):
+        input_height = latents_height // max_unet_downscale * model.total_downscale_factor
+        input_width = latents_width // max_unet_downscale * model.total_downscale_factor
+
+        t2i_image = prepare_control_image(
+            image=image,
+            do_classifier_free_guidance=False,
+            width=input_width,
+            height=input_height,
+            num_channels=model.config["in_channels"],  # mypy treats this as a FrozenDict
+            device=model.device,
+            dtype=model.dtype,
+            resize_mode=resize_mode,
+        )
+
+        return model(t2i_image)
+
+    @callback(ExtensionCallbackType.PRE_UNET)
+    def pre_unet_step(self, ctx: DenoiseContext):
+        # skip if model not active in current step
+        total_steps = len(ctx.inputs.timesteps)
+        first_step = math.floor(self._begin_step_percent * total_steps)
+        last_step = math.ceil(self._end_step_percent * total_steps)
+        if ctx.step_index < first_step or ctx.step_index > last_step:
+            return
+
+        weight = self._weight
+        if isinstance(weight, list):
+            weight = weight[ctx.step_index]
+
+        adapter_state = self._adapter_state
+        if ctx.conditioning_mode == ConditioningMode.Both:
+            adapter_state = [torch.cat([v] * 2) for v in adapter_state]
+
+        if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
+            ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
+        else:
+            for i, value in enumerate(adapter_state):
+                ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight
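
To make the two numeric pieces of this extension concrete, here is a rough worked example of the control-image sizing in _run_model and the begin/end step gating in pre_unet_step. The total_downscale_factor value of 64 is an assumption for a typical full SD1.5 T2I-Adapter (the real value is read from the loaded diffusers T2IAdapter), and the step counts are arbitrary illustrative numbers.

import math

# --- Control-image size: snap the adapter input to what the UNet can consume ---
latents_height, latents_width = 64, 64   # e.g. a 512x512 SD1.5 generation (latents are 1/8 scale)
max_unet_downscale = 8                   # SD1 UNet downscales latents by up to 8x internally
total_downscale_factor = 64              # assumed value for a full SD1.5 T2I-Adapter

input_height = latents_height // max_unet_downscale * total_downscale_factor  # 64 // 8 * 64 = 512
input_width = latents_width // max_unet_downscale * total_downscale_factor    # 512

# --- Step gating: which denoise steps the adapter is active for ---
total_steps = 30
begin_step_percent, end_step_percent = 0.0, 0.5
first_step = math.floor(begin_step_percent * total_steps)  # 0
last_step = math.ceil(end_step_percent * total_steps)      # 15
# pre_unet_step() returns early (adapter inactive) when
# step_index < first_step or step_index > last_step.

The integer floor division by max_unet_downscale followed by the multiplication by the adapter's total downscale factor snaps the control image to a resolution that both the adapter and the UNet's deepest feature maps can divide evenly, so the residuals added to down_intrablock_additional_residuals line up with the UNet's intermediate activations.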
