11from math import sqrt
2- from typing import Any , Callable , Optional
2+ from typing import Any , Callable , Optional , Tuple
33
44import torch
55import torch .nn as nn
@@ -69,6 +69,17 @@ def forward(
6969 ) -> Tensor :
7070 raise NotImplementedError ()
7171
def inpaint(
    self,
    source: Tensor,
    mask: Tensor,
    fn: Callable,
    sigmas: Tensor,
    num_steps: int,
    num_resamples: int,
) -> Tensor:
    """Base hook for mask-guided inpainting; concrete samplers override this.

    Raises:
        NotImplementedError: always — the default sampler has no inpainting.
    """
    raise NotImplementedError("Inpainting not available with current sampler")
7283
7384class KarrasSampler (Sampler ):
7485 """https://arxiv.org/abs/2206.00364 algorithm 1"""
@@ -128,18 +139,22 @@ def forward(
128139class ADPM2Sampler (Sampler ):
129140 """https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py"""
130141
131- """ https://www.desmos.com/calculator/jbxjlqd9mb """
142+ """https://www.desmos.com/calculator/jbxjlqd9mb"""
132143
133144 def __init__ (self , rho : float = 1.0 ):
134145 super ().__init__ ()
135146 self .rho = rho
136147
137- def step (self , x : Tensor , fn : Callable , sigma : float , sigma_next : float ) -> Tensor :
138- # Sigma steps
def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
    """Split one ancestral step into its three noise levels.

    Returns (sigma_up, sigma_down, sigma_mid): the noise re-injected after
    the step, the level actually stepped down to, and a rho-weighted
    midpoint between sigma and sigma_down used for the second-order update.
    """
    rho = self.rho
    # Variance split: sigma_up^2 + sigma_down^2 == sigma_next^2
    sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
    sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
    # Geometric-style midpoint controlled by rho (rho=1 -> arithmetic mean)
    sigma_mid = ((sigma ** (1 / rho) + sigma_down ** (1 / rho)) / 2) ** rho
    return sigma_up, sigma_down, sigma_mid
154+
155+ def step (self , x : Tensor , fn : Callable , sigma : float , sigma_next : float ) -> Tensor :
156+ # Sigma steps
157+ sigma_up , sigma_down , sigma_mid = self .get_sigmas (sigma , sigma_next )
143158 # Derivative at sigma (∂x/∂sigma)
144159 d = (x - fn (x , sigma = sigma )) / sigma
145160 # Denoise to midpoint
@@ -161,6 +176,31 @@ def forward(
161176 x = self .step (x , fn = fn , sigma = sigmas [i ], sigma_next = sigmas [i + 1 ]) # type: ignore # noqa
162177 return x
163178
def inpaint(
    self,
    source: Tensor,
    mask: Tensor,
    fn: Callable,
    sigmas: Tensor,
    num_steps: int,
    num_resamples: int,
) -> Tensor:
    """Mask-guided sampling: keep `source` where `mask` is True, generate the rest.

    At each noise level the known region is re-noised from `source` and merged
    into the running estimate before a sampler step; `num_resamples` repeats
    per level let the known and generated regions blend.
    """
    # Start from pure noise at the highest noise level
    x = torch.randn_like(source) * sigmas[0]

    for step_idx in range(num_steps - 1):
        sigma_curr, sigma_next = sigmas[step_idx], sigmas[step_idx + 1]
        # Known pixels, noised up to the current level
        noisy_source = source + torch.randn_like(source) * sigma_curr
        for resample in range(num_resamples):
            # Overwrite the known region, then take one denoising step
            x = noisy_source * mask + x * ~mask
            x = self.step(x, fn=fn, sigma=sigma_curr, sigma_next=sigma_next)  # type: ignore # noqa
            # Re-noise back up unless this was the final resample at this level
            if resample + 1 < num_resamples:
                bridge_sigma = sqrt(sigma_curr ** 2 - sigma_next ** 2)
                x = x + bridge_sigma * torch.randn_like(x)

    # Restore the clean known region exactly in the output
    return source * mask + x * ~mask
203+
164204
165205""" Diffusion Classes """
166206
@@ -188,17 +228,16 @@ def __init__(
188228 self .sigma_distribution = sigma_distribution
189229 self .dynamic_threshold = dynamic_threshold
190230
191- def c_skip (self , sigmas : Tensor ) -> Tensor :
192- return (self .sigma_data ** 2 ) / (sigmas ** 2 + self .sigma_data ** 2 )
193-
194- def c_out (self , sigmas : Tensor ) -> Tensor :
195- return sigmas * self .sigma_data * (self .sigma_data ** 2 + sigmas ** 2 ) ** - 0.5
196-
197- def c_in (self , sigmas : Tensor ) -> Tensor :
198- return 1 * (sigmas ** 2 + self .sigma_data ** 2 ) ** - 0.5
199-
200- def c_noise (self , sigmas : Tensor ) -> Tensor :
201- return torch .log (sigmas ) * 0.25
def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
    """Return the preconditioning weights (c_skip, c_out, c_in, c_noise).

    sigmas is per-batch (shape `b`); the first three weights are padded to
    `b 1 1` so they broadcast over channel/time dims, while c_noise stays
    per-batch for the network's noise conditioning.
    """
    sigma_data = self.sigma_data
    padded = rearrange(sigmas, "b -> b 1 1")
    # Shared total-variance term (sigma^2 + sigma_data^2)
    total_var = padded ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / total_var
    c_out = padded * sigma_data * total_var ** -0.5
    c_in = total_var ** -0.5
    c_noise = 0.25 * torch.log(sigmas)
    return c_skip, c_out, c_in, c_noise
202241
203242 def denoise_fn (
204243 self ,
@@ -216,13 +255,10 @@ def denoise_fn(
216255
217256 assert exists (sigmas )
218257
219- sigmas_padded = rearrange (sigmas , "b -> b 1 1" )
220-
221258 # Predict network output and add skip connection
222- x_pred = self .net (self .c_in (sigmas_padded ) * x_noisy , self .c_noise (sigmas ))
223- x_denoised = (
224- self .c_skip (sigmas_padded ) * x_noisy + self .c_out (sigmas_padded ) * x_pred
225- )
259+ c_skip , c_out , c_in , c_noise = self .get_scale_weights (sigmas )
260+ x_pred = self .net (c_in * x_noisy , c_noise )
261+ x_denoised = c_skip * x_noisy + c_out * x_pred
226262
227263 # Dynamic thresholding
228264 if self .dynamic_threshold == 0.0 :
@@ -294,94 +330,32 @@ def forward(self, noise: Tensor, num_steps: Optional[int] = None) -> Tensor:
294330
295331
class DiffusionInpainter(nn.Module):
    """Inpainting wrapper that delegates to the sampler's `inpaint` routine."""

    def __init__(
        self,
        diffusion: Diffusion,
        *,
        num_steps: int,
        num_resamples: int,
        sampler: Sampler,
        sigma_schedule: Schedule,
    ):
        super().__init__()
        # Bound denoise function of the wrapped diffusion model
        self.denoise_fn = diffusion.denoise_fn
        self.num_steps = num_steps
        self.num_resamples = num_resamples
        # The sampler must implement `inpaint` (the base Sampler raises)
        self.inpaint_fn = sampler.inpaint
        self.sigma_schedule = sigma_schedule

    @torch.no_grad()
    def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
        # Noise levels for the run, on the same device as the input
        sigmas = self.sigma_schedule(self.num_steps, inpaint.device)
        return self.inpaint_fn(
            source=inpaint,
            mask=inpaint_mask,
            fn=self.denoise_fn,
            sigmas=sigmas,
            num_steps=self.num_steps,
            num_resamples=self.num_resamples,
        )
386360
387361
0 commit comments