@@ -87,13 +87,7 @@ def __init__(
8787 self .s_churn = s_churn
8888
8989 def step (
90- self ,
91- x : Tensor ,
92- fn : Callable ,
93- sigma : float ,
94- sigma_next : float ,
95- gamma : float ,
96- clamp : bool = True ,
90+ self , x : Tensor , fn : Callable , sigma : float , sigma_next : float , gamma : float
9791 ) -> Tensor :
9892 """Algorithm 2 (step)"""
9993 # Select temporarily increased noise level
@@ -102,12 +96,12 @@ def step(
10296 epsilon = self .s_noise * torch .randn_like (x )
10397 x_hat = x + sqrt (sigma_hat ** 2 - sigma ** 2 ) * epsilon
10498 # Evaluate ∂x/∂sigma at sigma_hat
105- d = (x_hat - fn (x_hat , sigma = sigma_hat , clamp = clamp )) / sigma_hat
99+ d = (x_hat - fn (x_hat , sigma = sigma_hat )) / sigma_hat
106100 # Take euler step from sigma_hat to sigma_next
107101 x_next = x_hat + (sigma_next - sigma_hat ) * d
108102 # Second order correction
109103 if sigma_next != 0 :
110- model_out_next = fn (x_next , sigma = sigma_next , clamp = clamp )
104+ model_out_next = fn (x_next , sigma = sigma_next )
111105 d_prime = (x_next - model_out_next ) / sigma_next
112106 x_next = x_hat + 0.5 * (sigma - sigma_hat ) * (d + d_prime )
113107 return x_next
@@ -140,25 +134,18 @@ def __init__(self, rho: float = 1.0):
140134 super ().__init__ ()
141135 self .rho = rho
142136
143- def step (
144- self ,
145- x : Tensor ,
146- fn : Callable ,
147- sigma : float ,
148- sigma_next : float ,
149- clamp : bool = True ,
150- ) -> Tensor :
137+ def step (self , x : Tensor , fn : Callable , sigma : float , sigma_next : float ) -> Tensor :
151138 # Sigma steps
152139 r = self .rho
153140 sigma_up = sqrt (sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2 ) / sigma ** 2 )
154141 sigma_down = sqrt (sigma_next ** 2 - sigma_up ** 2 )
155142 sigma_mid = ((sigma ** (1 / r ) + sigma_down ** (1 / r )) / 2 ) ** r
156143 # Derivative at sigma (∂x/∂sigma)
157- d = (x - fn (x , sigma = sigma , clamp = clamp )) / sigma
144+ d = (x - fn (x , sigma = sigma )) / sigma
158145 # Denoise to midpoint
159146 x_mid = x + d * (sigma_mid - sigma )
160147 # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
161- d_mid = (x_mid - fn (x_mid , sigma = sigma_mid , clamp = clamp )) / sigma_mid
148+ d_mid = (x_mid - fn (x_mid , sigma = sigma_mid )) / sigma_mid
162149 # Denoise to next
163150 x = x + d_mid * (sigma_down - sigma )
164151 # Add randomness
@@ -178,6 +165,11 @@ def forward(
178165""" Diffusion Classes """
179166
180167
def pad_dims(x: Tensor, ndim: int) -> Tensor:
    """Return a view of ``x`` with ``ndim`` trailing singleton dimensions.

    Args:
        x: Tensor to pad.
        ndim: Number of size-1 dimensions appended after the existing ones.
            Zero or a negative count leaves the shape unchanged.

    Returns:
        A view of ``x`` with shape ``(*x.shape, 1, ..., 1)``.
    """
    # Pass the target shape as a single tuple: unpacking an empty shape
    # (0-dim tensor with ndim == 0) would call ``view()`` with no arguments,
    # which raises a TypeError.
    target_shape = tuple(x.shape) + (1,) * ndim
    return x.view(target_shape)
171+
172+
181173class Diffusion (nn .Module ):
182174 """Elucidated Diffusion: https://arxiv.org/abs/2206.00364"""
183175
@@ -187,12 +179,14 @@ def __init__(
187179 * ,
188180 sigma_distribution : Distribution ,
189181 sigma_data : float , # data distribution standard deviation
182+ dynamic_threshold : float = 0.0 ,
190183 ):
191184 super ().__init__ ()
192185
193186 self .net = net
194187 self .sigma_data = sigma_data
195188 self .sigma_distribution = sigma_distribution
189+ self .dynamic_threshold = dynamic_threshold
196190
def c_skip(self, sigmas: Tensor) -> Tensor:
    """Return ``sigma_data^2 / (sigmas^2 + sigma_data^2)`` for the given sigmas."""
    data_variance = self.sigma_data ** 2
    return data_variance / (sigmas ** 2 + data_variance)
@@ -211,7 +205,6 @@ def denoise_fn(
211205 x_noisy : Tensor ,
212206 sigmas : Optional [Tensor ] = None ,
213207 sigma : Optional [float ] = None ,
214- clamp : bool = False ,
215208 ) -> Tensor :
216209 batch , device = x_noisy .shape [0 ], x_noisy .device
217210
@@ -230,9 +223,20 @@ def denoise_fn(
230223 x_denoised = (
231224 self .c_skip (sigmas_padded ) * x_noisy + self .c_out (sigmas_padded ) * x_pred
232225 )
233- x_denoised = x_denoised .clamp (- 1.0 , 1 ) if clamp else x_denoised
234226
235- return x_denoised
227+ # Dynamic thresholding
228+ if self .dynamic_threshold == 0.0 :
229+ return x_denoised .clamp (- 1.0 , 1.0 )
230+ else :
231+ # Find the dynamic threshold quantile for each element of the batch
232+ x_flat = rearrange (x_denoised , "b ... -> b (...)" )
233+ scale = torch .quantile (x_flat .abs (), self .dynamic_threshold , dim = - 1 )
234+ # Clamp to a min of 1.0
235+ scale .clamp_ (min = 1.0 )
236+ # Clamp all values and scale
237+ scale = pad_dims (scale , ndim = x_denoised .ndim - scale .ndim )
238+ x_denoised = x_denoised .clamp (- scale , scale ) / scale
239+ return x_denoised
236240
237241 def loss_weight (self , sigmas : Tensor ) -> Tensor :
238242 # Computes weight depending on data distribution
@@ -335,12 +339,12 @@ def step(
335339 # Add increased noise to mixed value
336340 x_hat = x * ~ inpaint_mask + inpaint * inpaint_mask + noise
337341 # Evaluate ∂x/∂sigma at sigma_hat
338- d = (x_hat - self .denoise_fn (x_hat , sigma = sigma_hat , clamp = clamp )) / sigma_hat
342+ d = (x_hat - self .denoise_fn (x_hat , sigma = sigma_hat )) / sigma_hat
339343 # Take euler step from sigma_hat to sigma_next
340344 x_next = x_hat + (sigma_next - sigma_hat ) * d
341345 # Second order correction
342346 if sigma_next != 0 :
343- model_out_next = self .denoise_fn (x_next , sigma = sigma_next , clamp = clamp )
347+ model_out_next = self .denoise_fn (x_next , sigma = sigma_next )
344348 d_prime = (x_next - model_out_next ) / sigma_next
345349 x_next = x_hat + 0.5 * (sigma - sigma_hat ) * (d + d_prime )
346350 # Renoise for next resampling step
0 commit comments