
Commit 6992cc9

feat: update upsampler with proper resampling method, randomize in-batch resampling with multiple factors
Parent: da94c00

3 files changed, 73 additions(+), 14 deletions(-)

audio_diffusion_pytorch/model.py

Lines changed: 24 additions & 13 deletions
@@ -1,4 +1,3 @@
-import random
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
@@ -15,7 +14,7 @@
     Schedule,
 )
 from .modules import Bottleneck, MultiEncoder1d, UNet1d, UNetConditional1d
-from .utils import default, exists, to_list
+from .utils import default, downsample, exists, to_list, upsample
 
 """
 Diffusion Classes (generic for 1d data)
@@ -68,29 +67,41 @@ class DiffusionUpsampler1d(Model1d):
     def __init__(
         self, factor: Union[int, Sequence[int]], in_channels: int, *args, **kwargs
     ):
-        self.factor = to_list(factor)
+        self.factors = to_list(factor)
         default_kwargs = dict(
             in_channels=in_channels,
             context_channels=[in_channels],
         )
         super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
 
-    def forward(self, x: Tensor, factor: Optional[int] = None, **kwargs) -> Tensor:
-        # Either user provides factor or we pick one at random
-        factor = default(factor, random.choice(self.factor))
-        # Downsample by picking every `factor` item
-        downsampled = x[:, :, ::factor]
-        # Upsample by interleaving to get context
-        channels = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
+    def random_reupsample(self, x: Tensor) -> Tensor:
+        batch_size, factors = x.shape[0], self.factors
+        # Pick random factor for each batch element
+        factor_batch_idx = torch.randint(0, len(factors), (batch_size,))
+
+        for i, factor in enumerate(factors):
+            # Pick random items with current factor, skip if 0
+            n = torch.count_nonzero(factor_batch_idx == i)
+            if n > 0:
+                waveforms = x[factor_batch_idx == i]
+                # Downsample and reupsample items
+                downsampled = downsample(waveforms, factor=factor)
+                reupsampled = upsample(downsampled, factor=factor)
+                # Save reupsampled version in place
+                x[factor_batch_idx == i] = reupsampled
+        return x
+
+    def forward(self, x: Tensor, **kwargs) -> Tensor:
+        channels = self.random_reupsample(x)
         return self.diffusion(x, channels_list=[channels], **kwargs)
 
     def sample(  # type: ignore
         self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
     ):
         # Either user provides factor or we pick the first
-        factor = default(factor, self.factor[0])
-        # Upsample channels by interleaving
-        channels = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
+        factor = default(factor, self.factors[0])
+        # Upsample channels
+        channels = upsample(undersampled, factor=factor)
        noise = torch.randn_like(channels)
         default_kwargs = dict(channels_list=[channels])
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
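A minimal usage sketch (not part of this commit) of the updated upsampler. It assumes, as elsewhere in this library, that the forward pass returns the diffusion training loss; UNet and sampler keyword arguments beyond those shown in this diff are elided rather than invented.

import torch
from audio_diffusion_pytorch.model import DiffusionUpsampler1d

# Several factors: random_reupsample draws one per batch element during training.
upsampler = DiffusionUpsampler1d(
    factor=[2, 4, 8],
    in_channels=1,
    # ... remaining UNet1d / diffusion kwargs, not shown in this diff
)

x = torch.randn(4, 1, 2**15)  # [batch, channels, samples]
loss = upsampler(x)           # each item is down/re-upsampled with its own random factor
loss.backward()

# At sampling time a single factor is chosen explicitly (defaults to factors[0]);
# sampler-specific kwargs are forwarded to Model1d.sample and omitted here.
undersampled = torch.randn(1, 1, 2**12)
upsampled = upsampler.sample(undersampled, factor=8)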

audio_diffusion_pytorch/utils.py

Lines changed: 48 additions & 0 deletions
@@ -1,7 +1,12 @@
+import math
 from functools import reduce
 from inspect import isfunction
 from typing import Callable, List, Optional, Sequence, TypeVar, Union
 
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import Tensor
 from typing_extensions import TypeGuard
 
 T = TypeVar("T")
@@ -35,3 +40,46 @@ def to_list(val: Union[T, Sequence[T]]) -> List[T]:
 
 def prod(vals: Sequence[int]) -> int:
     return reduce(lambda x, y: x * y, vals)
+
+
+"""
+DSP Utils
+"""
+
+
+def resample(
+    waveforms: Tensor,
+    factor_in: int,
+    factor_out: int,
+    rolloff: float = 0.99,
+    lowpass_filter_width: int = 6,
+) -> Tensor:
+    """Resamples a waveform using sinc interpolation, adapted from torchaudio"""
+    b, _, length = waveforms.shape
+    length_target = int(factor_out * length / factor_in)
+    d = dict(device=waveforms.device, dtype=waveforms.dtype)
+
+    base_factor = min(factor_in, factor_out) * rolloff
+    width = math.ceil(lowpass_filter_width * factor_in / base_factor)
+    idx = torch.arange(-width, width + factor_in, **d)[None, None] / factor_in  # type: ignore # noqa
+    t = torch.arange(0, -factor_out, step=-1, **d)[:, None, None] / factor_out + idx  # type: ignore # noqa
+    t = (t * base_factor).clamp(-lowpass_filter_width, lowpass_filter_width) * math.pi
+
+    window = torch.cos(t / lowpass_filter_width / 2) ** 2
+    scale = base_factor / factor_in
+    kernels = torch.where(t == 0, torch.tensor(1.0).to(t), t.sin() / t)
+    kernels *= window * scale
+
+    waveforms = rearrange(waveforms, "b c t -> (b c) t")
+    waveforms = F.pad(waveforms, (width, width + factor_in))
+    resampled = F.conv1d(waveforms[:, None], kernels, stride=factor_in)
+    resampled = rearrange(resampled, "(b c) k l -> b c (l k)", b=b)
+    return resampled[..., :length_target]
+
+
+def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
+    return resample(waveforms, factor_in=factor, factor_out=1, **kwargs)
+
+
+def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
+    return resample(waveforms, factor_in=1, factor_out=factor, **kwargs)
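A small, self-contained sketch of what these DSP utilities change for the upsampler's context channels: the removed code decimated and repeated samples, while the new code round-trips through the windowed-sinc resampler above, so the context is band-limited rather than aliased. Tensor shapes are illustrative.

import torch
from audio_diffusion_pytorch.utils import downsample, upsample

x = torch.randn(2, 1, 2**14)  # [batch, channels, samples]
factor = 4

# Old context construction: decimate, then repeat each kept sample `factor` times
old_context = torch.repeat_interleave(x[:, :, ::factor], repeats=factor, dim=2)

# New context construction: sinc low-pass resampling down and back up
low = downsample(x, factor=factor)          # -> [2, 1, 4096]
new_context = upsample(low, factor=factor)  # -> [2, 1, 16384]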

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.48",
+    version="0.0.49",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
