Skip to content

Commit e4c118f

Browse files
feat: option to condition upsampling with used factor
1 parent aaaa699 commit e4c118f

File tree

3 files changed

+51
-13
lines changed

3 files changed

+51
-13
lines changed

audio_diffusion_pytorch/model.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313
Sampler,
1414
Schedule,
1515
)
16-
from .modules import Bottleneck, MultiEncoder1d, UNet1d, UNetConditional1d
16+
from .modules import (
17+
Bottleneck,
18+
MultiEncoder1d,
19+
SinusoidalEmbedding,
20+
UNet1d,
21+
UNetConditional1d,
22+
)
1723
from .utils import default, downsample, exists, to_list, upsample
1824

1925
"""
@@ -65,46 +71,64 @@ def sample(
6571

6672
class DiffusionUpsampler1d(Model1d):
    """Diffusion model conditioned on a degraded (down- then re-upsampled) input.

    During training each batch item is degraded with a factor picked at random
    from ``factor`` and passed to the diffusion objective as a context channel.
    When ``factor_features`` is given, the factor that was used is additionally
    embedded with a ``SinusoidalEmbedding`` and passed as ``features``.

    Args:
        in_channels: Number of channels of the waveform; also used as the
            single entry of ``context_channels``.
        factor: One factor or a sequence of candidate factors.
        factor_features: Size of the factor embedding. ``None`` disables
            conditioning on the factor.
        *args, **kwargs: Forwarded to the parent constructor; explicit kwargs
            override the defaults built here.
    """

    def __init__(
        self,
        in_channels: int,
        factor: Union[int, Sequence[int]],
        factor_features: Optional[int] = None,
        *args,
        **kwargs
    ):
        self.factors = to_list(factor)
        self.use_conditioning = exists(factor_features)

        default_kwargs = dict(
            in_channels=in_channels,
            context_channels=[in_channels],
            context_features=factor_features if self.use_conditioning else None,
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore

        if self.use_conditioning:
            # Redundant at runtime; narrows Optional[int] for the type checker.
            assert exists(factor_features)
            self.to_features = SinusoidalEmbedding(dim=factor_features)

    def random_reupsample(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Degrade each batch item with a randomly chosen factor.

        Returns:
            A tuple ``(degraded, used_factors)``: a copy of ``x`` where every
            item was downsampled and upsampled again, and a 1-D tensor with
            the factor *value* (not its index in ``self.factors``) applied to
            each item, on the same device as ``x``.
        """
        batch_size, device, factors = x.shape[0], x.device, self.factors
        # Pick a random index into `factors` for each batch element. Sampled on
        # x's device so the masks and returned factors never need a transfer.
        factor_idx = torch.randint(0, len(factors), (batch_size,), device=device)
        x = x.clone()

        for i, factor in enumerate(factors):
            # Select the items assigned to the current factor, skip if none
            mask = factor_idx == i
            if mask.any():
                waveforms = x[mask]
                # Downsample and reupsample items
                downsampled = downsample(waveforms, factor=factor)
                reupsampled = upsample(downsampled, factor=factor)
                # Save reupsampled version in place
                x[mask] = reupsampled

        # Map indices to factor values so that training conditions on the same
        # quantity that `sample` embeds at inference time.
        used_factors = torch.tensor(factors, device=device)[factor_idx]
        return x, used_factors

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Run the diffusion objective on ``x`` with its degraded copy as context."""
        channels, factors = self.random_reupsample(x)
        features = self.to_features(factors) if self.use_conditioning else None
        return self.diffusion(x, channels_list=[channels], features=features, **kwargs)

    def sample(  # type: ignore
        self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
    ):
        """Sample an upsampled version of ``undersampled``.

        Args:
            undersampled: Low-rate input batch.
            factor: Upsampling factor; defaults to the first configured factor.
            **kwargs: Forwarded to the parent ``sample``; may override
                ``channels_list`` and ``features``.

        NOTE(review): positional ``*args`` are accepted but not forwarded to
        the parent ``sample`` - confirm this is intended.
        """
        batch_size, device = undersampled.shape[0], undersampled.device
        # Either user provides factor or we pick the first
        factor = default(factor, self.factors[0])
        # Upsample channels by interpolation
        channels = upsample(undersampled, factor=factor)
        # Compute features if conditioning on factor
        features = None
        if self.use_conditioning:
            factors = torch.tensor([factor] * batch_size, device=device)
            features = self.to_features(factors)
        # Diffuse upsampled
        noise = torch.randn_like(channels)
        default_kwargs = dict(channels_list=[channels], features=features)
        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
109133

110134

audio_diffusion_pytorch/modules.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import math
12
from math import pi
23
from typing import Any, List, Optional, Sequence, Tuple, Union
34

@@ -488,6 +489,19 @@ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
488489
"""
489490

490491

492+
class SinusoidalEmbedding(nn.Module):
    """Fixed (parameter-free) sinusoidal embedding of scalar values.

    Each input value is multiplied by ``dim // 2`` frequencies spaced
    geometrically from 1 down to 1/10000; the sines of those products are
    concatenated with their cosines along the last axis, so the output has
    ``2 * (dim // 2)`` features per value.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x`` of shape ``(n,)`` into shape ``(n, 2 * (dim // 2))``."""
        device, half_dim = x.device, self.dim // 2
        # Clamp the divisor: with a single frequency (dim of 2 or 3) the
        # original `half_dim - 1` is zero and raised ZeroDivisionError.
        step = math.log(10000) / max(half_dim - 1, 1)
        emb = torch.tensor(step, device=device)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # Outer product of values and frequencies via broadcasting
        emb = x.unsqueeze(-1) * emb
        return torch.cat((emb.sin(), emb.cos()), dim=-1)
503+
504+
491505
class LearnedPositionalEmbedding(nn.Module):
492506
"""Used for continuous time"""
493507

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.57",
6+
version="0.0.58",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)