Commit bcbb510

feat: add new v-inpainter

1 parent a16e835 commit bcbb510

4 files changed: +96 −1 lines changed

README.md

Lines changed: 31 additions & 0 deletions

@@ -170,6 +170,37 @@ latent = autoencoder.encode(audio) # Encode
 sample = autoencoder.decode(latent, num_steps=10) # Decode by sampling diffusion model conditioning on latent
 ```

+## Other
+
+### Inpainting
+```py
+import torch
+from audio_diffusion_pytorch import UNetV0, VInpainter
+
+# The diffusion UNetV0 (this is an example, the net must be trained to work)
+net = UNetV0(
+    dim=1,
+    in_channels=2, # U-Net: number of input/output (audio) channels
+    channels=[8, 32, 64, 128, 256, 512, 512, 1024, 1024], # U-Net: channels at each layer
+    factors=[1, 4, 4, 4, 2, 2, 2, 2, 2], # U-Net: downsampling and upsampling factors at each layer
+    items=[1, 2, 2, 2, 2, 2, 2, 4, 4], # U-Net: number of repeating items at each layer
+    attentions=[0, 0, 0, 0, 0, 1, 1, 1, 1], # U-Net: attention enabled/disabled at each layer
+    attention_heads=8, # U-Net: number of attention heads per attention block
+    attention_features=64, # U-Net: number of attention features per attention block
+)
+
+# Instantiate inpainter with trained net
+inpainter = VInpainter(net=net)
+
+# Inpaint source
+y = inpainter(
+    source=torch.randn(1, 2, 2**18), # Start source
+    mask=torch.randint(0, 2, (1, 2, 2**18), dtype=torch.bool), # Set to `True` the parts you want to keep
+    num_steps=10, # Number of inpainting steps
+    num_resamples=2, # Number of resampling steps
+    show_progress=True,
+) # [1, 2, 2**18]
+```
+
 ## Appreciation

 * [StabilityAI](https://stability.ai/) for the compute, [Zach Evans](https://github.com/zqevans) and everyone else from [HarmonAI](https://www.harmonai.org/) for the interesting research discussions.
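The mask is a boolean tensor with the same shape as `source`: `True` marks samples to keep, `False` marks spans for the model to regenerate. As a sketch of a more typical mask than the random one above (reusing the `inpainter` from the snippet; the net must still be trained for the output to be meaningful), this keeps the first half of a clip and inpaints the second half:

```py
import torch

length = 2**18
source = torch.randn(1, 2, length)  # stand-in for a real stereo clip

# Keep the first half (True), let the model regenerate the second half (False)
mask = torch.zeros(1, 2, length, dtype=torch.bool)
mask[..., : length // 2] = True

y = inpainter(source=source, mask=mask, num_steps=10, num_resamples=2)  # [1, 2, length]
```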

audio_diffusion_pytorch/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@
     Schedule,
     UniformDistribution,
     VDiffusion,
+    VInpainter,
     VSampler,
 )
 from .models import (
@@ -15,4 +16,5 @@
     DiffusionModel,
     DiffusionUpsampler,
     DiffusionVocoder,
+    EncoderBase,
 )
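Both new symbols resolve from the package root after this change; a quick smoke test (assuming the package at this commit is installed):

```py
from audio_diffusion_pytorch import EncoderBase, VInpainter

print(VInpainter.__name__, EncoderBase.__name__)  # VInpainter EncoderBase
```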

audio_diffusion_pytorch/diffusion.py

Lines changed: 62 additions & 0 deletions

@@ -8,6 +8,8 @@
 from torch import Tensor
 from tqdm import tqdm

+from .utils import default
+
 """ Distributions """


@@ -166,6 +168,7 @@ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
         alpha, beta = torch.cos(angle), torch.sin(angle)
         return alpha, beta

+    @torch.no_grad()
     def forward(  # type: ignore
         self, x_noisy: Tensor, num_steps: int, show_progress: bool = False, **kwargs
     ) -> Tensor:
@@ -242,6 +245,7 @@ def sample_start(self, num_items: int, num_steps: int, **kwargs) -> Tensor:
         # Sample start
         return self.sample_loop(current=noise, sigmas=sigmas, **kwargs)

+    @torch.no_grad()
     def forward(
         self,
         num_items: int,
@@ -289,3 +293,61 @@ def forward(
             chunks += [torch.randn(shape, device=self.device)]

         return torch.cat(chunks[:num_chunks], dim=-1)
+
+
+""" Inpainters """
+
+
+class Inpainter(nn.Module):
+    pass
+
+
+class VInpainter(Inpainter):
+
+    diffusion_types = [VDiffusion]
+
+    def __init__(self, net: nn.Module, schedule: Schedule = LinearSchedule()):
+        super().__init__()
+        self.net = net
+        self.schedule = schedule
+
+    def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
+        angle = sigmas * pi / 2
+        alpha, beta = torch.cos(angle), torch.sin(angle)
+        return alpha, beta
+
+    @torch.no_grad()
+    def forward(  # type: ignore
+        self,
+        source: Tensor,
+        mask: Tensor,
+        num_steps: int,
+        num_resamples: int,
+        show_progress: bool = False,
+        x_noisy: Optional[Tensor] = None,
+        **kwargs,
+    ) -> Tensor:
+        x_noisy = default(x_noisy, lambda: torch.randn_like(source))
+        b = x_noisy.shape[0]
+        sigmas = self.schedule(num_steps + 1, device=x_noisy.device)
+        sigmas = repeat(sigmas, "i -> i b", b=b)
+        sigmas_batch = extend_dim(sigmas, dim=x_noisy.ndim + 1)
+        alphas, betas = self.get_alpha_beta(sigmas_batch)
+        progress_bar = tqdm(range(num_steps), disable=not show_progress)
+
+        for i in progress_bar:
+            for r in range(num_resamples):
+                v_pred = self.net(x_noisy, sigmas[i], **kwargs)
+                x_pred = alphas[i] * x_noisy - betas[i] * v_pred
+                noise_pred = betas[i] * x_noisy + alphas[i] * v_pred
+                # Renoise to current noise level if resampling
+                j = r == num_resamples - 1
+                x_noisy = alphas[i + j] * x_pred + betas[i + j] * noise_pred
+                s_noisy = alphas[i + j] * source + betas[i + j] * torch.randn_like(
+                    source
+                )
+                x_noisy = s_noisy * mask + x_noisy * ~mask
+
+            progress_bar.set_description(f"Inpainting (noise={sigmas[i+1,0]:.2f})")
+
+        return x_noisy
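Background on the update rule (standard v-objective algebra, stated here for readability rather than taken from the commit): with alpha = cos(pi·sigma/2) and beta = sin(pi·sigma/2), the forward process mixes signal and noise as x_sigma = alpha·x_0 + beta·eps, and the network predicts v = alpha·eps − beta·x_0. Inverting that pair gives exactly the `x_pred` and `noise_pred` lines in the loop:

```latex
x_\sigma = \alpha\,x_0 + \beta\,\epsilon, \qquad v = \alpha\,\epsilon - \beta\,x_0
\quad\Longrightarrow\quad
\hat{x}_0 = \alpha\,x_\sigma - \beta\,\hat{v}, \qquad \hat{\epsilon} = \beta\,x_\sigma + \alpha\,\hat{v}
```

The boolean `j = r == num_resamples - 1` is used as an index offset: on every resampling pass except the last it is 0, so the estimate is renoised back to the current level `sigmas[i]` and denoised again (in the spirit of RePaint's resampling trick, which lets generated and kept regions blend); on the last pass it is 1 and the state advances to `sigmas[i + 1]`. The final masked assignment then re-injects the kept regions of `source`, noised to the matching level, in place of the generated content.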

setup.py

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.1.2",
+    version="0.1.3",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
