Skip to content

Commit 85549a6

Browse files
feat: add ADPM2 inpainting with repaint
1 parent d0145f2 commit 85549a6

File tree

3 files changed

+74
-105
lines changed

3 files changed

+74
-105
lines changed

README.md

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ loss.backward() # Do this many times
2727
noise = torch.randn(2, 1, 2 ** 18)
2828
sampled = model.sample(
2929
noise=noise,
30-
num_steps=5 # Suggested range: 2-50
30+
num_steps=5 # Suggested range: 2-100
3131
) # [2, 1, 262144]
3232
```
3333

@@ -88,7 +88,7 @@ from audio_diffusion_pytorch import DiffusionSampler, KarrasSchedule
8888

8989
sampler = DiffusionSampler(
9090
diffusion,
91-
num_steps=5, # Suggested range 1-100, higher better quality but takes longer
91+
num_steps=5, # Suggested range 2-100, higher better quality but takes longer
9292
sampler=ADPM2Sampler(rho=1),
9393
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)
9494
)
@@ -98,20 +98,15 @@ y = sampler(noise = torch.randn(1,1,2 ** 18))
9898

9999
#### Inpainting
100100

101-
Note: this is fixed to the `KarrasSampler`, needs to be updated to custom sampler.
102-
103101
```py
104-
from audio_diffusion_pytorch import DiffusionInpainter, KarrasSchedule
102+
from audio_diffusion_pytorch import DiffusionInpainter, KarrasSchedule, ADPM2Sampler
105103

106104
inpainter = DiffusionInpainter(
107105
diffusion,
108-
num_steps=50, # Suggested range 32-1000, higher for better quality
109-
num_resamples=5, # Suggested range 1-10, higher for better quality
106+
num_steps=5, # Suggested range 2-100, higher for better quality
107+
num_resamples=1, # Suggested range 1-10, higher for better quality
110108
sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
111-
s_tmin=0,
112-
s_tmax=10,
113-
s_churn=40,
114-
s_noise=1.003
109+
sampler=ADPM2Sampler(rho=1.0),
115110
)
116111

117112
inpaint = torch.randn(1,1,2 ** 18) # Start track, e.g. one sampled with DiffusionSampler
@@ -147,7 +142,7 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
147142
- [x] Add ancestral DPM2 sampler.
148143
- [x] Add dynamic thresholding.
149144
- [x] Add (variational) autoencoder option to compress audio before diffusion.
150-
- [ ] Fix inpainting and make it work with ADPM2 sampler.
145+
- [x] Fix inpainting and make it work with ADPM2 sampler.
151146

152147
## Appreciation
153148

audio_diffusion_pytorch/diffusion.py

Lines changed: 66 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from math import sqrt
2-
from typing import Any, Callable, Optional
2+
from typing import Any, Callable, Optional, Tuple
33

44
import torch
55
import torch.nn as nn
@@ -69,6 +69,17 @@ def forward(
6969
) -> Tensor:
7070
raise NotImplementedError()
7171

72+
def inpaint(
73+
self,
74+
source: Tensor,
75+
mask: Tensor,
76+
fn: Callable,
77+
sigmas: Tensor,
78+
num_steps: int,
79+
num_resamples: int,
80+
) -> Tensor:
81+
raise NotImplementedError("Inpainting not available with current sampler")
82+
7283

7384
class KarrasSampler(Sampler):
7485
"""https://arxiv.org/abs/2206.00364 algorithm 1"""
@@ -128,18 +139,22 @@ def forward(
128139
class ADPM2Sampler(Sampler):
129140
"""https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py"""
130141

131-
""" https://www.desmos.com/calculator/jbxjlqd9mb """
142+
"""https://www.desmos.com/calculator/jbxjlqd9mb"""
132143

133144
def __init__(self, rho: float = 1.0):
134145
super().__init__()
135146
self.rho = rho
136147

137-
def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
138-
# Sigma steps
148+
def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
139149
r = self.rho
140150
sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
141151
sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
142152
sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
153+
return sigma_up, sigma_down, sigma_mid
154+
155+
def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
156+
# Sigma steps
157+
sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
143158
# Derivative at sigma (∂x/∂sigma)
144159
d = (x - fn(x, sigma=sigma)) / sigma
145160
# Denoise to midpoint
@@ -161,6 +176,31 @@ def forward(
161176
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
162177
return x
163178

179+
def inpaint(
180+
self,
181+
source: Tensor,
182+
mask: Tensor,
183+
fn: Callable,
184+
sigmas: Tensor,
185+
num_steps: int,
186+
num_resamples: int,
187+
) -> Tensor:
188+
x = sigmas[0] * torch.randn_like(source)
189+
190+
for i in range(num_steps - 1):
191+
# Noise source to current noise level
192+
source_noisy = source + sigmas[i] * torch.randn_like(source)
193+
for r in range(num_resamples):
194+
# Merge noisy source and current then denoise
195+
x = source_noisy * mask + x * ~mask
196+
x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
197+
# Renoise if not last resample step
198+
if r < num_resamples - 1:
199+
sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
200+
x = x + sigma * torch.randn_like(x)
201+
202+
return source * mask + x * ~mask
203+
164204

165205
""" Diffusion Classes """
166206

@@ -188,17 +228,16 @@ def __init__(
188228
self.sigma_distribution = sigma_distribution
189229
self.dynamic_threshold = dynamic_threshold
190230

191-
def c_skip(self, sigmas: Tensor) -> Tensor:
192-
return (self.sigma_data ** 2) / (sigmas ** 2 + self.sigma_data ** 2)
193-
194-
def c_out(self, sigmas: Tensor) -> Tensor:
195-
return sigmas * self.sigma_data * (self.sigma_data ** 2 + sigmas ** 2) ** -0.5
196-
197-
def c_in(self, sigmas: Tensor) -> Tensor:
198-
return 1 * (sigmas ** 2 + self.sigma_data ** 2) ** -0.5
199-
200-
def c_noise(self, sigmas: Tensor) -> Tensor:
201-
return torch.log(sigmas) * 0.25
231+
def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
232+
sigma_data = self.sigma_data
233+
sigmas_padded = rearrange(sigmas, "b -> b 1 1")
234+
c_skip = (sigma_data ** 2) / (sigmas_padded ** 2 + sigma_data ** 2)
235+
c_out = (
236+
sigmas_padded * sigma_data * (sigma_data ** 2 + sigmas_padded ** 2) ** -0.5
237+
)
238+
c_in = (sigmas_padded ** 2 + sigma_data ** 2) ** -0.5
239+
c_noise = torch.log(sigmas) * 0.25
240+
return c_skip, c_out, c_in, c_noise
202241

203242
def denoise_fn(
204243
self,
@@ -216,13 +255,10 @@ def denoise_fn(
216255

217256
assert exists(sigmas)
218257

219-
sigmas_padded = rearrange(sigmas, "b -> b 1 1")
220-
221258
# Predict network output and add skip connection
222-
x_pred = self.net(self.c_in(sigmas_padded) * x_noisy, self.c_noise(sigmas))
223-
x_denoised = (
224-
self.c_skip(sigmas_padded) * x_noisy + self.c_out(sigmas_padded) * x_pred
225-
)
259+
c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
260+
x_pred = self.net(c_in * x_noisy, c_noise)
261+
x_denoised = c_skip * x_noisy + c_out * x_pred
226262

227263
# Dynamic thresholding
228264
if self.dynamic_threshold == 0.0:
@@ -294,94 +330,32 @@ def forward(self, noise: Tensor, num_steps: Optional[int] = None) -> Tensor:
294330

295331

296332
class DiffusionInpainter(nn.Module):
297-
"""RePaint Inpainting: https://arxiv.org/abs/2201.09865"""
298-
299333
def __init__(
300334
self,
301335
diffusion: Diffusion,
302336
*,
303337
num_steps: int,
304338
num_resamples: int,
339+
sampler: Sampler,
305340
sigma_schedule: Schedule,
306-
s_tmin: float = 0,
307-
s_tmax: float = float("inf"),
308-
s_churn: float = 0.0,
309-
s_noise: float = 1.0,
310341
):
311342
super().__init__()
312343
self.denoise_fn = diffusion.denoise_fn
313344
self.num_steps = num_steps
314345
self.num_resamples = num_resamples
346+
self.inpaint_fn = sampler.inpaint
315347
self.sigma_schedule = sigma_schedule
316-
self.s_tmin = s_tmin
317-
self.s_tmax = s_tmax
318-
self.s_noise = s_noise
319-
self.s_churn = s_churn
320-
321-
def step(
322-
self,
323-
x: Tensor,
324-
*,
325-
inpaint: Tensor,
326-
inpaint_mask: Tensor,
327-
sigma: float,
328-
sigma_next: float,
329-
gamma: float,
330-
renoise: bool,
331-
clamp: bool = True,
332-
) -> Tensor:
333-
"""Algorithm 2 (step)"""
334-
# Select temporarily increased noise level
335-
sigma_hat = sigma + gamma * sigma
336-
# Noise to move from sigma to sigma_hat
337-
epsilon = self.s_noise * torch.randn_like(x)
338-
noise = sqrt(sigma_hat ** 2 - sigma ** 2) * epsilon
339-
# Add increased noise to mixed value
340-
x_hat = x * ~inpaint_mask + inpaint * inpaint_mask + noise
341-
# Evaluate ∂x/∂sigma at sigma_hat
342-
d = (x_hat - self.denoise_fn(x_hat, sigma=sigma_hat)) / sigma_hat
343-
# Take euler step from sigma_hat to sigma_next
344-
x_next = x_hat + (sigma_next - sigma_hat) * d
345-
# Second order correction
346-
if sigma_next != 0:
347-
model_out_next = self.denoise_fn(x_next, sigma=sigma_next)
348-
d_prime = (x_next - model_out_next) / sigma_next
349-
x_next = x_hat + 0.5 * (sigma - sigma_hat) * (d + d_prime)
350-
# Renoise for next resampling step
351-
if renoise:
352-
x_next = x_next + (sigma - sigma_next) * torch.randn_like(x_next)
353-
return x_next
354348

355349
@torch.no_grad()
356350
def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
357-
device = inpaint.device
358-
num_steps, num_resamples = self.num_steps, self.num_resamples
359-
# Compute sigmas using schedule
360-
sigmas = self.sigma_schedule(num_steps, device)
361-
# Sample from first sigma distribution
362-
x = sigmas[0] * torch.randn_like(inpaint)
363-
# Compute gammas
364-
gammas = torch.where(
365-
(sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
366-
min(self.s_churn / num_steps, sqrt(2) - 1),
367-
0.0,
351+
x = self.inpaint_fn(
352+
source=inpaint,
353+
mask=inpaint_mask,
354+
fn=self.denoise_fn,
355+
sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
356+
num_steps=self.num_steps,
357+
num_resamples=self.num_resamples,
368358
)
369-
370-
for i in range(num_steps - 1):
371-
for r in range(num_resamples):
372-
x = self.step(
373-
x=x,
374-
inpaint=inpaint,
375-
inpaint_mask=inpaint_mask,
376-
sigma=sigmas[i],
377-
sigma_next=sigmas[i + 1],
378-
gamma=gammas[i], # type: ignore # noqa
379-
renoise=i < num_steps - 1 and r < num_resamples,
380-
)
381-
382-
x = x.clamp(-1.0, 1.0)
383-
# Make sure inpainted area is same as input
384-
x = x * ~inpaint_mask + inpaint * inpaint_mask
385359
return x
386360

387361

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.13",
6+
version="0.0.14",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)