Skip to content

Commit ee7503c

Browse files
authored
Modular backend - T2I Adapter (#6662)
## Summary T2I Adapter code from #6577. ## Related Issues / Discussions #6606 https://invokeai.notion.site/Modular-Stable-Diffusion-Backend-Design-Document-e8952daab5d5472faecdc4a72d377b0d ## QA Instructions Run with and without set `USE_MODULAR_DENOISE` environment. ## Merge Plan Nope. If you think that there should be some kind of tests - feel free to add. ## Checklist - [x] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
2 parents e8e2482 + 310719e commit ee7503c

File tree

2 files changed

+149
-0
lines changed

2 files changed

+149
-0
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
6464
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
6565
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
66+
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
6667
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
6768
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
6869
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -499,6 +500,33 @@ def parse_controlnet_field(
499500
)
500501
)
501502

503+
@staticmethod
504+
def parse_t2i_adapter_field(
505+
exit_stack: ExitStack,
506+
context: InvocationContext,
507+
t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
508+
ext_manager: ExtensionsManager,
509+
) -> None:
510+
if t2i_adapters is None:
511+
return
512+
513+
# Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
514+
if isinstance(t2i_adapters, T2IAdapterField):
515+
t2i_adapters = [t2i_adapters]
516+
517+
for t2i_adapter_field in t2i_adapters:
518+
ext_manager.add_extension(
519+
T2IAdapterExt(
520+
node_context=context,
521+
model_id=t2i_adapter_field.t2i_adapter_model,
522+
image=context.images.get_pil(t2i_adapter_field.image.image_name),
523+
weight=t2i_adapter_field.weight,
524+
begin_step_percent=t2i_adapter_field.begin_step_percent,
525+
end_step_percent=t2i_adapter_field.end_step_percent,
526+
resize_mode=t2i_adapter_field.resize_mode,
527+
)
528+
)
529+
502530
def prep_ip_adapter_image_prompts(
503531
self,
504532
context: InvocationContext,
@@ -845,6 +873,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
845873
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
846874
# ext_manager.add_extension(ext)
847875
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
876+
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
848877

849878
# ext: t2i/ip adapter
850879
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from __future__ import annotations
2+
3+
import math
4+
from typing import TYPE_CHECKING, List, Optional, Union
5+
6+
import torch
7+
from diffusers import T2IAdapter
8+
from PIL.Image import Image
9+
10+
from invokeai.app.util.controlnet_utils import prepare_control_image
11+
from invokeai.backend.model_manager import BaseModelType
12+
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
13+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
14+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
15+
16+
if TYPE_CHECKING:
17+
from invokeai.app.invocations.model import ModelIdentifierField
18+
from invokeai.app.services.shared.invocation_context import InvocationContext
19+
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
20+
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
21+
22+
23+
class T2IAdapterExt(ExtensionBase):
24+
def __init__(
25+
self,
26+
node_context: InvocationContext,
27+
model_id: ModelIdentifierField,
28+
image: Image,
29+
weight: Union[float, List[float]],
30+
begin_step_percent: float,
31+
end_step_percent: float,
32+
resize_mode: CONTROLNET_RESIZE_VALUES,
33+
):
34+
super().__init__()
35+
self._node_context = node_context
36+
self._model_id = model_id
37+
self._image = image
38+
self._weight = weight
39+
self._resize_mode = resize_mode
40+
self._begin_step_percent = begin_step_percent
41+
self._end_step_percent = end_step_percent
42+
43+
self._adapter_state: Optional[List[torch.Tensor]] = None
44+
45+
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
46+
model_config = self._node_context.models.get_config(self._model_id.key)
47+
if model_config.base == BaseModelType.StableDiffusion1:
48+
self._max_unet_downscale = 8
49+
elif model_config.base == BaseModelType.StableDiffusionXL:
50+
self._max_unet_downscale = 4
51+
else:
52+
raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.")
53+
54+
@callback(ExtensionCallbackType.SETUP)
55+
def setup(self, ctx: DenoiseContext):
56+
t2i_model: T2IAdapter
57+
with self._node_context.models.load(self._model_id) as t2i_model:
58+
_, _, latents_height, latents_width = ctx.inputs.orig_latents.shape
59+
60+
self._adapter_state = self._run_model(
61+
model=t2i_model,
62+
image=self._image,
63+
latents_height=latents_height,
64+
latents_width=latents_width,
65+
)
66+
67+
def _run_model(
68+
self,
69+
model: T2IAdapter,
70+
image: Image,
71+
latents_height: int,
72+
latents_width: int,
73+
):
74+
# Resize the T2I-Adapter input image.
75+
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
76+
# result will match the latent image's dimensions after max_unet_downscale is applied.
77+
input_height = latents_height // self._max_unet_downscale * model.total_downscale_factor
78+
input_width = latents_width // self._max_unet_downscale * model.total_downscale_factor
79+
80+
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
81+
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
82+
# T2I-Adapter model.
83+
#
84+
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
85+
# of the same requirements (e.g. preserving binary masks during resize).
86+
t2i_image = prepare_control_image(
87+
image=image,
88+
do_classifier_free_guidance=False,
89+
width=input_width,
90+
height=input_height,
91+
num_channels=model.config["in_channels"],
92+
device=model.device,
93+
dtype=model.dtype,
94+
resize_mode=self._resize_mode,
95+
)
96+
97+
return model(t2i_image)
98+
99+
@callback(ExtensionCallbackType.PRE_UNET)
100+
def pre_unet_step(self, ctx: DenoiseContext):
101+
# skip if model not active in current step
102+
total_steps = len(ctx.inputs.timesteps)
103+
first_step = math.floor(self._begin_step_percent * total_steps)
104+
last_step = math.ceil(self._end_step_percent * total_steps)
105+
if ctx.step_index < first_step or ctx.step_index > last_step:
106+
return
107+
108+
weight = self._weight
109+
if isinstance(weight, list):
110+
weight = weight[ctx.step_index]
111+
112+
adapter_state = self._adapter_state
113+
if ctx.conditioning_mode == ConditioningMode.Both:
114+
adapter_state = [torch.cat([v] * 2) for v in adapter_state]
115+
116+
if ctx.unet_kwargs.down_intrablock_additional_residuals is None:
117+
ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state]
118+
else:
119+
for i, value in enumerate(adapter_state):
120+
ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight

0 commit comments

Comments
 (0)