Skip to content

Commit c323a76

Browse files
Suggested changes
Co-Authored-By: Ryan Dick <[email protected]>
1 parent 9d1fcba commit c323a76

File tree

3 files changed

+68
-23
lines changed

3 files changed

+68
-23
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -732,10 +732,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
732732
dtype = TorchDevice.choose_torch_dtype()
733733

734734
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
735-
latents = latents.to(device=device, dtype=dtype)
736-
if noise is not None:
737-
noise = noise.to(device=device, dtype=dtype)
738-
739735
_, _, latent_height, latent_width = latents.shape
740736

741737
conditioning_data = self.get_conditioning_data(
@@ -768,21 +764,6 @@ def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
768764
denoising_end=self.denoising_end,
769765
)
770766

771-
denoise_ctx = DenoiseContext(
772-
inputs=DenoiseInputs(
773-
orig_latents=latents,
774-
timesteps=timesteps,
775-
init_timestep=init_timestep,
776-
noise=noise,
777-
seed=seed,
778-
scheduler_step_kwargs=scheduler_step_kwargs,
779-
conditioning_data=conditioning_data,
780-
attention_processor_cls=CustomAttnProcessor2_0,
781-
),
782-
unet=None,
783-
scheduler=scheduler,
784-
)
785-
786767
# get the unet's config so that we can pass the base to sd_step_callback()
787768
unet_config = context.models.get_config(self.unet.unet.key)
788769

@@ -799,6 +780,26 @@ def step_callback(state: PipelineIntermediateState) -> None:
799780
elif mask is not None:
800781
ext_manager.add_extension(InpaintExt(mask, is_gradient_mask))
801782

783+
# Initialize context for modular denoise
784+
latents = latents.to(device=device, dtype=dtype)
785+
if noise is not None:
786+
noise = noise.to(device=device, dtype=dtype)
787+
788+
denoise_ctx = DenoiseContext(
789+
inputs=DenoiseInputs(
790+
orig_latents=latents,
791+
timesteps=timesteps,
792+
init_timestep=init_timestep,
793+
noise=noise,
794+
seed=seed,
795+
scheduler_step_kwargs=scheduler_step_kwargs,
796+
conditioning_data=conditioning_data,
797+
attention_processor_cls=CustomAttnProcessor2_0,
798+
),
799+
unet=None,
800+
scheduler=scheduler,
801+
)
802+
802803
# ext: t2i/ip adapter
803804
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
804805

invokeai/backend/stable_diffusion/extensions/inpaint.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,40 @@
1414

1515

1616
class InpaintExt(ExtensionBase):
17+
"""An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
18+
models.
19+
"""
1720
def __init__(
1821
self,
1922
mask: torch.Tensor,
2023
is_gradient_mask: bool,
2124
):
25+
"""Initialize InpaintExt.
26+
Args:
27+
mask (torch.Tensor): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
28+
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
29+
inpainted.
30+
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
31+
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
32+
1.
33+
"""
2234
super().__init__()
2335
self._mask = mask
2436
self._is_gradient_mask = is_gradient_mask
37+
38+
# Noise, which used to noisify unmasked part of image
39+
# if noise provided to context, then it will be used
40+
# if no noise provided, then noise will be generated based on seed
2541
self._noise: Optional[torch.Tensor] = None
2642

2743
@staticmethod
2844
def _is_normal_model(unet: UNet2DConditionModel):
45+
""" Checks if the provided UNet belongs to a regular model.
46+
The `in_channels` of a UNet vary depending on model type:
47+
- normal - 4
48+
- depth - 5
49+
- inpaint - 9
50+
"""
2951
return unet.conv_in.in_channels == 4
3052

3153
def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
@@ -42,8 +64,8 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso
4264
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
4365
mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
4466
if self._is_gradient_mask:
45-
threshhold = (t.item()) / ctx.scheduler.config.num_train_timesteps
46-
mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
67+
threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
68+
mask_bool = mask > threshold
4769
masked_input = torch.where(mask_bool, latents, mask_latents)
4870
else:
4971
masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
@@ -52,11 +74,13 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso
5274
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
5375
def init_tensors(self, ctx: DenoiseContext):
5476
if not self._is_normal_model(ctx.unet):
55-
raise Exception("InpaintExt should be used only on normal models!")
77+
raise ValueError("InpaintExt should be used only on normal models!")
5678

5779
self._mask = self._mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
5880

5981
self._noise = ctx.inputs.noise
82+
# 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
83+
# We still need noise for inpainting, so we generate it from the seed here.
6084
if self._noise is None:
6185
self._noise = torch.randn(
6286
ctx.latents.shape,

invokeai/backend/stable_diffusion/extensions/inpaint_model.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,26 @@
1313

1414

1515
class InpaintModelExt(ExtensionBase):
16+
"""An extension for inpainting with inpainting models. See `InpaintExt` for inpainting with non-inpainting
17+
models.
18+
"""
1619
def __init__(
1720
self,
1821
mask: Optional[torch.Tensor],
1922
masked_latents: Optional[torch.Tensor],
2023
is_gradient_mask: bool,
2124
):
25+
"""Initialize InpaintModelExt.
26+
Args:
27+
mask (Optional[torch.Tensor]): The inpainting mask. Shape: (1, 1, latent_height, latent_width). Values are
28+
expected to be in the range [0, 1]. A value of 0 means that the corresponding 'pixel' should not be
29+
inpainted.
30+
masked_latents (Optional[torch.Tensor]): Latents of initial image, with masked out by black color inpainted area.
31+
If mask provided, then too should be provided. Shape: (1, 1, latent_height, latent_width)
32+
is_gradient_mask (bool): If True, mask is interpreted as a gradient mask meaning that the mask values range
33+
from 0 to 1. If False, mask is interpreted as binary mask meaning that the mask values are either 0 or
34+
1.
35+
"""
2236
super().__init__()
2337
if mask is not None and masked_latents is None:
2438
raise ValueError("Source image required for inpaint mask when inpaint model used!")
@@ -29,12 +43,18 @@ def __init__(
2943

3044
@staticmethod
3145
def _is_inpaint_model(unet: UNet2DConditionModel):
46+
""" Checks if the provided UNet belongs to a regular model.
47+
The `in_channels` of a UNet vary depending on model type:
48+
- normal - 4
49+
- depth - 5
50+
- inpaint - 9
51+
"""
3252
return unet.conv_in.in_channels == 9
3353

3454
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
3555
def init_tensors(self, ctx: DenoiseContext):
3656
if not self._is_inpaint_model(ctx.unet):
37-
raise Exception("InpaintModelExt should be used only on inpaint models!")
57+
raise ValueError("InpaintModelExt should be used only on inpaint models!")
3858

3959
if self._mask is None:
4060
self._mask = torch.ones_like(ctx.latents[:1, :1])

0 commit comments

Comments
 (0)