Skip to content

Commit a5f1069

Browse files
feat: add dynamic thresholding
1 parent 661c392 commit a5f1069

File tree

3 files changed: 33 additions and 25 deletions

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ from audio_diffusion_pytorch import Diffusion, LogNormalDistribution
7373
diffusion = Diffusion(
7474
net=unet,
7575
sigma_distribution=LogNormalDistribution(mean = -3.0, std = 1.0),
76-
sigma_data=0.1
76+
sigma_data=0.1,
77+
dynamic_threshold=0.95
7778
)
7879

7980
x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples

audio_diffusion_pytorch/diffusion.py

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,7 @@ def __init__(
8787
self.s_churn = s_churn
8888

8989
def step(
90-
self,
91-
x: Tensor,
92-
fn: Callable,
93-
sigma: float,
94-
sigma_next: float,
95-
gamma: float,
96-
clamp: bool = True,
90+
self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
9791
) -> Tensor:
9892
"""Algorithm 2 (step)"""
9993
# Select temporarily increased noise level
@@ -102,12 +96,12 @@ def step(
10296
epsilon = self.s_noise * torch.randn_like(x)
10397
x_hat = x + sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
10498
# Evaluate ∂x/∂sigma at sigma_hat
105-
d = (x_hat - fn(x_hat, sigma=sigma_hat, clamp=clamp)) / sigma_hat
99+
d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
106100
# Take euler step from sigma_hat to sigma_next
107101
x_next = x_hat + (sigma_next - sigma_hat) * d
108102
# Second order correction
109103
if sigma_next != 0:
110-
model_out_next = fn(x_next, sigma=sigma_next, clamp=clamp)
104+
model_out_next = fn(x_next, sigma=sigma_next)
111105
d_prime = (x_next - model_out_next) / sigma_next
112106
x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
113107
return x_next
@@ -140,25 +134,18 @@ def __init__(self, rho: float = 1.0):
140134
super().__init__()
141135
self.rho = rho
142136

143-
def step(
144-
self,
145-
x: Tensor,
146-
fn: Callable,
147-
sigma: float,
148-
sigma_next: float,
149-
clamp: bool = True,
150-
) -> Tensor:
137+
def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
151138
# Sigma steps
152139
r = self.rho
153140
sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
154141
sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
155142
sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
156143
# Derivative at sigma (∂x/∂sigma)
157-
d = (x - fn(x, sigma=sigma, clamp=clamp)) / sigma
144+
d = (x - fn(x, sigma=sigma)) / sigma
158145
# Denoise to midpoint
159146
x_mid = x + d * (sigma_mid - sigma)
160147
# Derivative at sigma_mid (∂x_mid/∂sigma_mid)
161-
d_mid = (x_mid - fn(x_mid, sigma=sigma_mid, clamp=clamp)) / sigma_mid
148+
d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
162149
# Denoise to next
163150
x = x + d_mid * (sigma_down - sigma)
164151
# Add randomness
@@ -178,6 +165,11 @@ def forward(
178165
""" Diffusion Classes """
179166

180167

168+
def pad_dims(x: Tensor, ndim: int) -> Tensor:
169+
# Pads additional ndims to the right of the tensor
170+
return x.view(*x.shape, *((1,) * ndim))
171+
172+
181173
class Diffusion(nn.Module):
182174
"""Elucidated Diffusion: https://arxiv.org/abs/2206.00364"""
183175

@@ -187,12 +179,14 @@ def __init__(
187179
*,
188180
sigma_distribution: Distribution,
189181
sigma_data: float, # data distribution standard deviation
182+
dynamic_threshold: float = 0.0,
190183
):
191184
super().__init__()
192185

193186
self.net = net
194187
self.sigma_data = sigma_data
195188
self.sigma_distribution = sigma_distribution
189+
self.dynamic_threshold = dynamic_threshold
196190

197191
def c_skip(self, sigmas: Tensor) -> Tensor:
198192
return (self.sigma_data ** 2) / (sigmas ** 2 + self.sigma_data ** 2)
@@ -211,7 +205,6 @@ def denoise_fn(
211205
x_noisy: Tensor,
212206
sigmas: Optional[Tensor] = None,
213207
sigma: Optional[float] = None,
214-
clamp: bool = False,
215208
) -> Tensor:
216209
batch, device = x_noisy.shape[0], x_noisy.device
217210

@@ -230,9 +223,20 @@ def denoise_fn(
230223
x_denoised = (
231224
self.c_skip(sigmas_padded) * x_noisy + self.c_out(sigmas_padded) * x_pred
232225
)
233-
x_denoised = x_denoised.clamp(-1.0, 1) if clamp else x_denoised
234226

235-
return x_denoised
227+
# Dynamic thresholding
228+
if self.dynamic_threshold == 0.0:
229+
return x_denoised.clamp(-1.0, 1.0)
230+
else:
231+
# Find dynamic threshold quantile for each batch
232+
x_flat = rearrange(x_denoised, "b ... -> b (...)")
233+
scale = torch.quantile(x_flat.abs(), self.dynamic_threshold, dim=-1)
234+
# Clamp to a min of 1.0
235+
scale.clamp_(min=1.0)
236+
# Clamp all values and scale
237+
scale = pad_dims(scale, ndim=x_denoised.ndim - scale.ndim)
238+
x_denoised = x_denoised.clamp(-scale, scale) / scale
239+
return x_denoised
236240

237241
def loss_weight(self, sigmas: Tensor) -> Tensor:
238242
# Computes weight depending on data distribution
@@ -335,12 +339,12 @@ def step(
335339
# Add increased noise to mixed value
336340
x_hat = x * ~inpaint_mask + inpaint * inpaint_mask + noise
337341
# Evaluate ∂x/∂sigma at sigma_hat
338-
d = (x_hat - self.denoise_fn(x_hat, sigma=sigma_hat, clamp=clamp)) / sigma_hat
342+
d = (x_hat - self.denoise_fn(x_hat, sigma=sigma_hat)) / sigma_hat
339343
# Take euler step from sigma_hat to sigma_next
340344
x_next = x_hat + (sigma_next - sigma_hat) * d
341345
# Second order correction
342346
if sigma_next != 0:
343-
model_out_next = self.denoise_fn(x_next, sigma=sigma_next, clamp=clamp)
347+
model_out_next = self.denoise_fn(x_next, sigma=sigma_next)
344348
d_prime = (x_next - model_out_next) / sigma_next
345349
x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
346350
# Renoise for next resampling step

audio_diffusion_pytorch/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
use_attention_bottleneck: bool,
3838
diffusion_sigma_distribution: Distribution,
3939
diffusion_sigma_data: int,
40+
diffusion_dynamic_threshold: float,
4041
out_channels: Optional[int] = None,
4142
):
4243
super().__init__()
@@ -66,6 +67,7 @@ def __init__(
6667
net=self.unet,
6768
sigma_distribution=diffusion_sigma_distribution,
6869
sigma_data=diffusion_sigma_data,
70+
dynamic_threshold=diffusion_dynamic_threshold,
6971
)
7072

7173
def forward(self, x: Tensor) -> Tensor:
@@ -105,6 +107,7 @@ def __init__(self, *args, **kwargs):
105107
use_learned_time_embedding=True,
106108
diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
107109
diffusion_sigma_data=0.1,
110+
diffusion_dynamic_threshold=0.95,
108111
)
109112
super().__init__(*args, **{**default_kwargs, **kwargs})
110113

0 commit comments

Comments (0)