-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Modular backend - inpaint #6643
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
RyanJDick
merged 13 commits into
invoke-ai:main
from
StAlKeR7779:stalker7779/modular_inpaint
Jul 29, 2024
Merged
Changes from 4 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
9e7b470
Handle inpaint models
StAlKeR7779 58f3072
Handle inpainting on normal models
StAlKeR7779 5003e5d
Same changes as in other PRs, add check for running inpainting on inp…
StAlKeR7779 87eb018
Revert debug change
StAlKeR7779 9d1fcba
Fix create gradient mask node output
StAlKeR7779 c323a76
Suggested changes
StAlKeR7779 19c0024
Use non-inverted mask generally(except inpaint model handling)
StAlKeR7779 416d29f
Ruff format
StAlKeR7779 bd8890b
Revert "Fix create gradient mask node output"
StAlKeR7779 5810cee
Suggested changes
StAlKeR7779 ed0174f
Suggested changes
StAlKeR7779 84d0288
Revert wrong comment copy
StAlKeR7779 693a3ea
Merge branch 'main' into stalker-modular_inpaint-2
RyanJDick File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Optional | ||
|
||
import einops | ||
import torch | ||
from diffusers import UNet2DConditionModel | ||
|
||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType | ||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback | ||
|
||
if TYPE_CHECKING: | ||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext | ||
|
||
|
||
class InpaintExt(ExtensionBase): | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__( | ||
self, | ||
mask: torch.Tensor, | ||
is_gradient_mask: bool, | ||
): | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
super().__init__() | ||
self._mask = mask | ||
self._is_gradient_mask = is_gradient_mask | ||
self._noise: Optional[torch.Tensor] = None | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
@staticmethod | ||
def _is_normal_model(unet: UNet2DConditionModel): | ||
return unet.conv_in.in_channels == 4 | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor: | ||
batch_size = latents.size(0) | ||
mask = einops.repeat(self._mask, "b c h w -> (repeat b) c h w", repeat=batch_size) | ||
if t.dim() == 0: | ||
# some schedulers expect t to be one-dimensional. | ||
# TODO: file diffusers bug about inconsistency? | ||
t = einops.repeat(t, "-> batch", batch=batch_size) | ||
# Noise shouldn't be re-randomized between steps here. The multistep schedulers | ||
# get very confused about what is happening from step to step when we do that. | ||
mask_latents = ctx.scheduler.add_noise(ctx.inputs.orig_latents, self._noise, t) | ||
# TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? | ||
# mask_latents = self.scheduler.scale_model_input(mask_latents, t) | ||
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size) | ||
if self._is_gradient_mask: | ||
threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
masked_input = torch.where(mask_bool, latents, mask_latents) | ||
else: | ||
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype)) | ||
return masked_input | ||
|
||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) | ||
def init_tensors(self, ctx: DenoiseContext): | ||
if not self._is_normal_model(ctx.unet): | ||
raise Exception("InpaintExt should be used only on normal models!") | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) | ||
|
||
self._noise = ctx.inputs.noise | ||
if self._noise is None: | ||
self._noise = torch.randn( | ||
ctx.latents.shape, | ||
dtype=torch.float32, | ||
device="cpu", | ||
generator=torch.Generator(device="cpu").manual_seed(ctx.seed), | ||
).to(device=ctx.latents.device, dtype=ctx.latents.dtype) | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# TODO: order value | ||
@callback(ExtensionCallbackType.PRE_STEP, order=-100) | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def apply_mask_to_initial_latents(self, ctx: DenoiseContext): | ||
ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep) | ||
|
||
# TODO: order value | ||
# TODO: redo this with preview events rewrite | ||
@callback(ExtensionCallbackType.POST_STEP, order=-100) | ||
def apply_mask_to_step_output(self, ctx: DenoiseContext): | ||
timestep = ctx.scheduler.timesteps[-1] | ||
if hasattr(ctx.step_output, "denoised"): | ||
ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep) | ||
elif hasattr(ctx.step_output, "pred_original_sample"): | ||
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep) | ||
else: | ||
ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep) | ||
|
||
# TODO: should here be used order? | ||
# restore unmasked part after the last step is completed | ||
@callback(ExtensionCallbackType.POST_DENOISE_LOOP) | ||
def restore_unmasked(self, ctx: DenoiseContext): | ||
if self._is_gradient_mask: | ||
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents) | ||
else: | ||
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask) |
66 changes: 66 additions & 0 deletions
66
invokeai/backend/stable_diffusion/extensions/inpaint_model.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Optional | ||
|
||
import torch | ||
from diffusers import UNet2DConditionModel | ||
|
||
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType | ||
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback | ||
|
||
if TYPE_CHECKING: | ||
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext | ||
|
||
|
||
class InpaintModelExt(ExtensionBase): | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__( | ||
self, | ||
mask: Optional[torch.Tensor], | ||
masked_latents: Optional[torch.Tensor], | ||
is_gradient_mask: bool, | ||
): | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
super().__init__() | ||
if mask is not None and masked_latents is None: | ||
raise ValueError("Source image required for inpaint mask when inpaint model used!") | ||
|
||
self._mask = mask | ||
self._masked_latents = masked_latents | ||
self._is_gradient_mask = is_gradient_mask | ||
|
||
@staticmethod | ||
def _is_inpaint_model(unet: UNet2DConditionModel): | ||
return unet.conv_in.in_channels == 9 | ||
|
||
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP) | ||
def init_tensors(self, ctx: DenoiseContext): | ||
if not self._is_inpaint_model(ctx.unet): | ||
raise Exception("InpaintModelExt should be used only on inpaint models!") | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if self._mask is None: | ||
self._mask = torch.ones_like(ctx.latents[:1, :1]) | ||
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype) | ||
|
||
if self._masked_latents is None: | ||
self._masked_latents = torch.zeros_like(ctx.latents[:1]) | ||
self._masked_latents = self._masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype) | ||
|
||
# TODO: any ideas about order value? | ||
# do last so that other extensions works with normal latents | ||
@callback(ExtensionCallbackType.PRE_UNET, order=1000) | ||
def append_inpaint_layers(self, ctx: DenoiseContext): | ||
batch_size = ctx.unet_kwargs.sample.shape[0] | ||
b_mask = torch.cat([self._mask] * batch_size) | ||
b_masked_latents = torch.cat([self._masked_latents] * batch_size) | ||
ctx.unet_kwargs.sample = torch.cat( | ||
[ctx.unet_kwargs.sample, b_mask, b_masked_latents], | ||
dim=1, | ||
) | ||
|
||
# TODO: should here be used order? | ||
# restore unmasked part as inpaint model can change unmasked part slightly | ||
@callback(ExtensionCallbackType.POST_DENOISE_LOOP) | ||
def restore_unmasked(self, ctx: DenoiseContext): | ||
if self._is_gradient_mask: | ||
ctx.latents = torch.where(self._mask > 0, ctx.latents, ctx.inputs.orig_latents) | ||
StAlKeR7779 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
ctx.latents = torch.lerp(ctx.inputs.orig_latents, ctx.latents, self._mask) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.