Skip to content
Merged
1 change: 1 addition & 0 deletions invokeai/app/invocations/create_gradient_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:

# redistribute blur so that the original edges are 0 and blur outwards to 1
blur_tensor = (blur_tensor - 0.5) * 2
blur_tensor[blur_tensor < 0] = 0.0

threshold = 1 - self.minimum_denoise

Expand Down
42 changes: 27 additions & 15 deletions invokeai/app/invocations/denoise_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager import BaseModelType
from invokeai.backend.model_manager import BaseModelType, ModelVariantType
from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs
Expand All @@ -58,6 +58,8 @@
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt
from invokeai.backend.stable_diffusion.extensions.inpaint_model import InpaintModelExt
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
Expand Down Expand Up @@ -672,7 +674,7 @@ def prep_inpaint_mask(
else:
masked_latents = torch.where(mask < 0.5, 0.0, latents)

return 1 - mask, masked_latents, self.denoise_mask.gradient
return mask, masked_latents, self.denoise_mask.gradient

@staticmethod
def prepare_noise_and_latents(
Expand Down Expand Up @@ -730,10 +732,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
dtype = TorchDevice.choose_torch_dtype()

seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)

_, _, latent_height, latent_width = latents.shape

conditioning_data = self.get_conditioning_data(
Expand Down Expand Up @@ -766,6 +764,27 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
denoising_end=self.denoising_end,
)

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)

### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

ext_manager.add_extension(PreviewExt(step_callback))

### inpaint
mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
if unet_config.variant == ModelVariantType.Inpaint:
ext_manager.add_extension(InpaintModelExt(mask, masked_latents, is_gradient_mask))
elif mask is not None:
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))

# Initialize context for modular denoise
latents = latents.to(device=device, dtype=dtype)
if noise is not None:
noise = noise.to(device=device, dtype=dtype)

denoise_ctx = DenoiseContext(
inputs=DenoiseInputs(
orig_latents=latents,
Expand All @@ -781,15 +800,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
scheduler=scheduler,
)

# get the unet's config so that we can pass the base to sd_step_callback()
unet_config = context.models.get_config(self.unet.unet.key)

### preview
def step_callback(state: PipelineIntermediateState) -> None:
context.util.sd_step_callback(state, unet_config.base)

ext_manager.add_extension(PreviewExt(step_callback))

# ext: t2i/ip adapter
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)

Expand Down Expand Up @@ -820,6 +830,8 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
if mask is not None:
mask = 1 - mask

# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
Expand Down
116 changes: 116 additions & 0 deletions invokeai/backend/stable_diffusion/extensions/inpaint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import einops
import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


class InpaintExt(ExtensionBase):
"""An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
models.
"""
def __init__(
self,
mask: torch.Tensor,
is_gradient_mask: bool,
):
"""Initialize InpaintExt.
Args:
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
inpainted.
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
1.
"""
super().__init__()
self._mask = mask
self._is_gradient_mask = is_gradient_mask

# Noise, which used to noisify unmasked part of image
# if noise provided to context, then it will be used
# if no noise provided, then noise will be generated based on seed
self._noise: Optional[torch.Tensor] = None

@staticmethod
def _is_normal_model(unet: UNet2DConditionModel):
""" Checks if the provided UNet belongs to a regular model.
The `in_channels` of a UNet vary depending on model type:
- normal - 4
- depth - 5
- inpaint - 9
"""
return unet.conv_in.in_channels == 4

def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
batch_size = latents.size(0)
mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
if t.dim() == 0:
# some schedulers expect t to be one-dimensional.
# TODO: file diffusers bug about inconsistency?
t = einops.repeat(t, "-> batch", batch=batch_size)
# Noise shouldn't be re-randomized between steps here. The multistep schedulers
# get very confused about what is happening from step to step when we do that.
mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t)
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
if self._is_gradient_mask:
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
mask_bool = mask < 1 - threshold
masked_input = torch.where(mask_bool, latents, mask_latents)
else:
masked_input = torch.lerp(latents, mask_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype))
return masked_input

@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def init_tensors(self, ctx: DenoiseContext):
if not self._is_normal_model(ctx.unet):
raise ValueError("InpaintExt should be used only on normal models!")

self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

self._noise = ctx.inputs.noise
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
# We still need noise for inpainting, so we generate it from the seed here.
if self._noise is None:
self._noise = torch.randn(
ctx.latents.shape,
dtype=torch.float32,
device="cpu",
generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
).to(device=ctx.latents.device, dtype=ctx.latents.dtype)

# TODO: order value
@callback(ExtensionCallbackType.PRE_STEP, order=-100)
def apply_mask_to_initial_latents(self, ctx: DenoiseContext):
ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)

# TODO: order value
# TODO: redo this with preview events rewrite
@callback(ExtensionCallbackType.POST_STEP, order=-100)
def apply_mask_to_step_output(self, ctx: DenoiseContext):
timestep = ctx.scheduler.timesteps[-1]
if hasattr(ctx.step_output, "denoised"):
ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
elif hasattr(ctx.step_output, "pred_original_sample"):
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
else:
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)

# TODO: should here be used order?
# restore unmasked part after the last step is completed
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
def restore_unmasked(self, ctx: DenoiseContext):
if self._is_gradient_mask:
ctx.latents = torch.where(self._mask < 1, ctx.latents, ctx.inputs.orig_latents)
else:
ctx.latents = torch.lerp(ctx.latents, ctx.inputs.orig_latents, self._mask)
89 changes: 89 additions & 0 deletions invokeai/backend/stable_diffusion/extensions/inpaint_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

import torch
from diffusers import UNet2DConditionModel

from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback

if TYPE_CHECKING:
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext


class InpaintModelExt(ExtensionBase):
"""An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
models.
"""
def __init__(
self,
mask: Optional[torch.Tensor],
masked_latents: Optional[torch.Tensor],
is_gradient_mask: bool,
):
"""Initialize InpaintModelExt.
Args:
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
expected to be in the range [0, 1]. A value of 1 means that the corresponding 'pixel' should not be
inpainted.
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
1.
"""
super().__init__()
if mask is not None and masked_latents is None:
raise ValueError("Source image required for inpaint mask when inpaint model used!")

# Inverse mask, because inpaint models treat mask as: 0 - remain same, 1 - inpaint
self._mask = None
if mask is not None:
self._mask = 1 - mask
self._masked_latents = masked_latents
self._is_gradient_mask = is_gradient_mask

@staticmethod
def _is_inpaint_model(unet: UNet2DConditionModel):
""" Checks if the provided UNet belongs to a regular model.
The `in_channels` of a UNet vary depending on model type:
- normal - 4
- depth - 5
- inpaint - 9
"""
return unet.conv_in.in_channels == 9

@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
def init_tensors(self, ctx: DenoiseContext):
if not self._is_inpaint_model(ctx.unet):
raise ValueError("InpaintModelExt should be used only on inpaint models!")

if self._mask is None:
self._mask = torch.ones_like(ctx.latents[:1, :1])
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

if self._masked_latents is None:
self._masked_latents = torch.zeros_like(ctx.latents[:1])
self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)

# TODO: any ideas about order value?
# do last so that other extensions works with normal latents
@callback(ExtensionCallbackType.PRE_UNET, order=1000)
def append_inpaint_layers(self, ctx: DenoiseContext):
batch_size = ctx.unet_kwargs.sample.shape[0]
b_mask = torch.cat([self._mask] * batch_size)
b_masked_latents = torch.cat([self._masked_latents] * batch_size)
ctx.unet_kwargs.sample = torch.cat(
[ctx.unet_kwargs.sample, b_mask, b_masked_latents],
dim=1,
)

# TODO: should here be used order?
# restore unmasked part as inpaint model can change unmasked part slightly
@callback(ExtensionCallbackType.POST_DENOISE_LOOP)
def restore_unmasked(self, ctx: DenoiseContext):
if self._is_gradient_mask:
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents)
else:
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask)