14
14
15
15
16
16
class InpaintExt (ExtensionBase ):
17
+ """An extension for inpainting with non-inpainting models. See `InpaintModelExt` for inpainting with inpainting
18
+ models.
19
+ """
17
20
def __init__(
    self,
    mask: torch.Tensor,
    is_gradient_mask: bool,
):
    """Set up the inpainting extension state.

    Args:
        mask (torch.Tensor): Inpainting mask with shape (1, 1, latent_height, latent_width) and values expected
            to lie in [0, 1]. A value of 0 marks a 'pixel' that should not be inpainted.
        is_gradient_mask (bool): When True, the mask is treated as a gradient mask whose values range from 0 to
            1. When False, the mask is treated as a binary mask whose values are either 0 or 1.
    """
    super().__init__()

    # Noise used to noisify the unmasked part of the image.
    # If noise is provided to the context, that noise is used;
    # otherwise it is generated later based on the seed.
    self._noise: Optional[torch.Tensor] = None

    self._is_gradient_mask = is_gradient_mask
    self._mask = mask
26
42
27
43
@staticmethod
def _is_normal_model(unet: UNet2DConditionModel):
    """Report whether the given UNet belongs to a regular (non-depth, non-inpaint) model.

    The decision is based on the `in_channels` of the UNet's input convolution, which varies by model type:
    - normal - 4
    - depth - 5
    - inpaint - 9
    """
    input_channel_count = unet.conv_in.in_channels
    return input_channel_count == 4
30
52
31
53
def _apply_mask (self , ctx : DenoiseContext , latents : torch .Tensor , t : torch .Tensor ) -> torch .Tensor :
@@ -42,8 +64,8 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso
42
64
# mask_latents = self.scheduler.scale_model_input(mask_latents, t)
43
65
mask_latents = einops .repeat (mask_latents , "b c h w -> (repeat b) c h w" , repeat = batch_size )
44
66
if self ._is_gradient_mask :
45
- threshhold = (t .item ()) / ctx .scheduler .config .num_train_timesteps
46
- mask_bool = mask > threshhold # I don't know when mask got inverted, but it did
67
+ threshold = (t .item ()) / ctx .scheduler .config .num_train_timesteps
68
+ mask_bool = mask > threshold
47
69
masked_input = torch .where (mask_bool , latents , mask_latents )
48
70
else :
49
71
masked_input = torch .lerp (mask_latents .to (dtype = latents .dtype ), latents , mask .to (dtype = latents .dtype ))
@@ -52,11 +74,13 @@ def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tenso
52
74
@callback (ExtensionCallbackType .PRE_DENOISE_LOOP )
53
75
def init_tensors (self , ctx : DenoiseContext ):
54
76
if not self ._is_normal_model (ctx .unet ):
55
- raise Exception ("InpaintExt should be used only on normal models!" )
77
+ raise ValueError ("InpaintExt should be used only on normal models!" )
56
78
57
79
self ._mask = self ._mask .to (device = ctx .latents .device , dtype = ctx .latents .dtype )
58
80
59
81
self ._noise = ctx .inputs .noise
82
+ # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner).
83
+ # We still need noise for inpainting, so we generate it from the seed here.
60
84
if self ._noise is None :
61
85
self ._noise = torch .randn (
62
86
ctx .latents .shape ,
0 commit comments