Skip to content

Commit fa3c0c8

Browse files
Merge branch 'main' into stalker7779/fix_gradient_mask
2 parents eef88d1 + 66547b9 commit fa3c0c8

File tree

16 files changed

+564
-132
lines changed

16 files changed

+564
-132
lines changed

docker/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \
5555
FROM node:20-slim AS web-builder
5656
ENV PNPM_HOME="/pnpm"
5757
ENV PATH="$PNPM_HOME:$PATH"
58+
RUN corepack use pnpm@8.x
5859
RUN corepack enable
5960

6061
WORKDIR /build

invokeai/app/invocations/denoise_latents.py

Lines changed: 74 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
from invokeai.app.util.controlnet_utils import prepare_control_image
3838
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
3939
from invokeai.backend.lora import LoRAModelRaw
40-
from invokeai.backend.model_manager import BaseModelType
40+
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
4141
from invokeai.backend.model_patcher import ModelPatcher
42-
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
42+
from invokeai.backend.stable_diffusion import PipelineIntermediateState
4343
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
4444
from invokeai.backend.stable_diffusion.diffusers_pipeline import (
4545
ControlNetData,
@@ -60,8 +60,12 @@
6060
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
6161
from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
6262
from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
63+
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
64+
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
6365
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
6466
from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
67+
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
68+
from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt
6569
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
6670
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
6771
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -498,6 +502,33 @@ def parse_controlnet_field(
498502
)
499503
)
500504

505+
@staticmethod
def parse_t2i_adapter_field(
    exit_stack: ExitStack,
    context: InvocationContext,
    t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
    ext_manager: ExtensionsManager,
) -> None:
    """Register one `T2IAdapterExt` on `ext_manager` for every supplied T2I-Adapter field.

    Args:
        exit_stack: Not used by this method.
        context: Invocation context. It is handed to each extension as `node_context` and is used to load the
            adapter's image by name.
        t2i_adapters: A single field, a list of fields, or None. When None, nothing is registered.
        ext_manager: The extensions manager that receives the created extensions.
    """
    if t2i_adapters is None:
        return

    # A lone field is wrapped in a list so that both accepted input forms share the same loop.
    fields = [t2i_adapters] if isinstance(t2i_adapters, T2IAdapterField) else t2i_adapters

    for field in fields:
        extension = T2IAdapterExt(
            node_context=context,
            model_id=field.t2i_adapter_model,
            image=context.images.get_pil(field.image.image_name),
            weight=field.weight,
            begin_step_percent=field.begin_step_percent,
            end_step_percent=field.end_step_percent,
            resize_mode=field.resize_mode,
        )
        ext_manager.add_extension(extension)
531+
501532
def prep_ip_adapter_image_prompts(
502533
self,
503534
context: InvocationContext,
@@ -707,7 +738,7 @@ def prep_inpaint_mask(
707738
else:
708739
masked_latents = torch.where(mask < 0.5, 0.0, latents)
709740

710-
return 1 - mask, masked_latents, self.denoise_mask.gradient
741+
return mask, masked_latents, self.denoise_mask.gradient
711742

712743
@staticmethod
713744
def prepare_noise_and_latents(
@@ -765,10 +796,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
765796
dtype = TorchDevice.choose_torch_dtype()
766797

767798
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
768-
latents = latents.to(device=device, dtype=dtype)
769-
if noise is not None:
770-
noise = noise.to(device=device, dtype=dtype)
771-
772799
_, _, latent_height, latent_width = latents.shape
773800

774801
conditioning_data = self.get_conditioning_data(
@@ -801,21 +828,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
801828
denoising_end=self.denoising_end,
802829
)
803830

804-
denoise_ctx = DenoiseContext(
805-
inputs=DenoiseInputs(
806-
orig_latents=latents,
807-
timesteps=timesteps,
808-
init_timestep=init_timestep,
809-
noise=noise,
810-
seed=seed,
811-
scheduler_step_kwargs=scheduler_step_kwargs,
812-
conditioning_data=conditioning_data,
813-
attention_processor_cls=CustomAttnProcessor2_0,
814-
),
815-
unet=None,
816-
scheduler=scheduler,
817-
)
818-
819831
# get the unet's config so that we can pass the base to sd_step_callback()
820832
unet_config = context.models.get_config(self.unet.unet.key)
821833

@@ -833,13 +845,48 @@ def step_callback(state: PipelineIntermediateState) -> None:
833845
if self.unet.freeu_config:
834846
ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
835847

848+
### seamless
849+
if self.unet.seamless_axes:
850+
ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes))
851+
852+
### inpaint
853+
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
854+
# NOTE: We used to identify inpainting models by inspecting the shape of the loaded UNet model weights. Now we
855+
# use the ModelVariantType config. During testing, there was a report of a user with models that had an
856+
# incorrect ModelVariantType value. Re-installing the model fixed the issue. If this issue turns out to be
857+
# prevalent, we will have to revisit how we initialize the inpainting extensions.
858+
if unet_config.variant == ModelVariantType.Inpaint:
859+
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
860+
elif mask is not None:
861+
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
862+
863+
# Initialize context for modular denoise
864+
latents = latents.to(device=device, dtype=dtype)
865+
if noise is not None:
866+
noise = noise.to(device=device, dtype=dtype)
867+
denoise_ctx = DenoiseContext(
868+
inputs=DenoiseInputs(
869+
orig_latents=latents,
870+
timesteps=timesteps,
871+
init_timestep=init_timestep,
872+
noise=noise,
873+
seed=seed,
874+
scheduler_step_kwargs=scheduler_step_kwargs,
875+
conditioning_data=conditioning_data,
876+
attention_processor_cls=CustomAttnProcessor2_0,
877+
),
878+
unet=None,
879+
scheduler=scheduler,
880+
)
881+
836882
# context for loading additional models
837883
with ExitStack() as exit_stack:
838884
# Later this should be something like:
839885
# for extension_field in self.extensions:
840886
# ext = extension_field.to_extension(exit_stack, context, ext_manager)
841887
# ext_manager.add_extension(ext)
842888
self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
889+
self.parse_t2i_adapter_field(exit_stack, context, self.t2i_adapter, ext_manager)
843890

844891
# ext: t2i/ip adapter
845892
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
@@ -871,6 +918,10 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
871918
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
872919

873920
mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
921+
# At this point, the mask ranges from 0 (leave unchanged) to 1 (inpaint).
922+
# We invert the mask here for compatibility with the old backend implementation.
923+
if mask is not None:
924+
mask = 1 - mask
874925

875926
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
876927
# below. Investigate whether this is appropriate.
@@ -915,7 +966,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
915966
ExitStack() as exit_stack,
916967
unet_info.model_on_device() as (model_state_dict, unet),
917968
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
918-
set_seamless(unet, self.unet.seamless_axes), # FIXME
969+
SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME
919970
# Apply the LoRA after unet has been moved to its target device for faster patching.
920971
ModelPatcher.apply_lora_unet(
921972
unet,

invokeai/app/invocations/latents_to_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from invokeai.app.invocations.model import VAEField
2525
from invokeai.app.invocations.primitives import ImageOutput
2626
from invokeai.app.services.shared.invocation_context import InvocationContext
27-
from invokeai.backend.stable_diffusion import set_seamless
27+
from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt
2828
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
2929
from invokeai.backend.util.devices import TorchDevice
3030

@@ -59,7 +59,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
5959

6060
vae_info = context.models.load(self.vae.vae)
6161
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
62-
with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
62+
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
6363
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
6464
latents = latents.to(vae.device)
6565
if self.fp32:

invokeai/backend/stable_diffusion/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77
StableDiffusionGeneratorPipeline,
88
)
99
from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent # noqa: F401
10-
from invokeai.backend.stable_diffusion.seamless import set_seamless # noqa: F401
1110

1211
__all__ = [
1312
"PipelineIntermediateState",
1413
"StableDiffusionGeneratorPipeline",
1514
"InvokeAIDiffuserComponent",
16-
"set_seamless",
1715
]
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Optional
4+
5+
import einops
6+
import torch
7+
from diffusers import UNet2DConditionModel
8+
9+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
10+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
11+
12+
if TYPE_CHECKING:
13+
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
14+
15+
16+
class InpaintExt(ExtensionBase):
    """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
    models.
    """

    def __init__(
        self,
        mask: torch.Tensor,
        is_gradient_mask: bool,
    ):
        """Initialize InpaintExt.
        Args:
            mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
                expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
                inpainted.
            is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
                from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
                1.
        """
        super().__init__()
        self._mask = mask
        self._is_gradient_mask = is_gradient_mask

        # Noise used to re-noise the original latents for the preserved part of the image.
        # It is populated in `init_tensors`: the noise from the denoise inputs is used when present,
        # otherwise noise is generated from the seed.
        self._noise: Optional[torch.Tensor] = None

    @staticmethod
    def _is_normal_model(unet: UNet2DConditionModel) -> bool:
        """Checks if the provided UNet belongs to a regular model.
        The `in_channels` of a UNet vary depending on model type:
        - normal - 4
        - depth - 5
        - inpaint - 9
        """
        return unet.conv_in.in_channels == 4

    def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Blend `latents` with the original latents re-noised to timestep `t`, according to the mask.

        Args:
            ctx: The denoise context. Its scheduler and `inputs.orig_latents` are read.
            latents: The latents to patch.
            t: The timestep at which the original latents are noised.

        Returns:
            The patched latents. In gradient-mask mode, positions where `mask < 1 - t / num_train_timesteps` keep
            `latents` and all other positions take the noised original latents. Otherwise the result is a linear
            interpolation from `latents` (mask 0) to the noised original latents (mask 1).
        """
        batch_size = latents.size(0)
        mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
        if t.dim() == 0:
            # some schedulers expect t to be one-dimensional.
            # TODO: file diffusers bug about inconsistency?
            t = einops.repeat(t, "-> batch", batch=batch_size)
        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
        # get very confused about what is happening from step to step when we do that.
        mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
        mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
        if self._is_gradient_mask:
            # The cut-off depends on the timestep, so the set of positions that keep `latents` changes with `t`.
            threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
            mask_bool = mask < 1 - threshold
            masked_input = torch.where(mask_bool, latents, mask_latents)
        else:
            masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
        return masked_input

    @callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
    def init_tensors(self, ctx: DenoiseContext) -> None:
        """Validate the UNet type, move the mask to the latents' device/dtype and prepare the noise tensor.

        Raises:
            ValueError: If the UNet's input convolution does not have 4 input channels.
        """
        if not self._is_normal_model(ctx.unet):
            raise ValueError(
                "InpaintExt should be used only on normal (non-inpainting) models. This could be caused by an "
                "inpainting model that was incorrectly marked as a non-inpainting model. In some cases, this can be "
                "fixed by removing and re-adding the model (so that it gets re-probed)."
            )

        self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

        self._noise = ctx.inputs.noise
        # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
        # We still need noise for inpainting, so we generate it from the seed here.
        if self._noise is None:
            # The seed is supplied through the denoise inputs (alongside `noise`), so it is read from `ctx.inputs`.
            self._noise = torch.randn(
                ctx.latents.shape,
                dtype=torch.float32,
                device="cpu",
                generator=torch.Generator(device="cpu").manual_seed(ctx.inputs.seed),
            ).to(device=ctx.latents.device, dtype=ctx.latents.dtype)

    # Use negative order to make extensions with default order work with patched latents
    @callback(ExtensionCallbackType.PRE_STEP, order=-100)
    def apply_mask_to_initial_latents(self, ctx: DenoiseContext) -> None:
        """Patch the current latents with the mask at the current timestep before the step runs."""
        ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)

    # TODO: redo this with preview events rewrite
    # Use negative order to make extensions with default order work with patched latents
    @callback(ExtensionCallbackType.POST_STEP, order=-100)
    def apply_mask_to_step_output(self, ctx: DenoiseContext) -> None:
        """Patch the step output's predicted sample with the mask, using the scheduler's last timestep.

        The first available attribute of `denoised` and `pred_original_sample` is patched in place. If neither
        exists, `pred_original_sample` is set from the masked `prev_sample`.
        """
        timestep = ctx.scheduler.timesteps[-1]
        if hasattr(ctx.step_output, "denoised"):
            ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
        elif hasattr(ctx.step_output, "pred_original_sample"):
            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
        else:
            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)

    # Restore unmasked part after the last step is completed
    @callback(ExtensionCallbackType.POST_DENOISE_LOOP)
    def restore_unmasked(self, ctx: DenoiseContext) -> None:
        """Write the original latents back into the final latents according to the mask."""
        if self._is_gradient_mask:
            # Only positions where the mask is exactly 1 (or greater) are restored from the original latents.
            ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
        else:
            ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)

0 commit comments

Comments
 (0)