Skip to content

Commit 5bf1837

Browse files
feat: add noise channel conditioning augmentation (UNetNCCA1d), rename UNetConditional1d to UNetCFG1d, add UNetAll1d
1 parent c937e49 commit 5bf1837

File tree

4 files changed

+75
-16
lines changed

4 files changed

+75
-16
lines changed

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@
3535
DiffusionVocoder1d,
3636
Model1d,
3737
)
38-
from .modules import NumberEmbedder, T5Embedder, UNet1d, UNetConditional1d
38+
from .modules import NumberEmbedder, T5Embedder, UNet1d

audio_diffusion_pytorch/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tqdm import tqdm
1010

1111
from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion
12-
from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d, rand_bool
12+
from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetCFG1d, rand_bool
1313
from .utils import (
1414
closest_power_2,
1515
default,
@@ -34,7 +34,7 @@ def __init__(
3434
super().__init__()
3535
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
3636

37-
UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
37+
UNet = UNetCFG1d if use_classifier_free_guidance else UNet1d
3838
self.unet = UNet(**kwargs)
3939

4040
self.diffusion = XDiffusion(

audio_diffusion_pytorch/modules.py

Lines changed: 71 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from einops_exts import rearrange_many
99
from torch import Tensor, einsum
1010

11-
from .utils import closest_power_2, default, exists, groupby
11+
from .utils import closest_power_2, default, exists, groupby, is_sequence
1212

1313
"""
1414
Utils
@@ -909,9 +909,11 @@ def __init__(
909909
self.use_stft = use_stft
910910
self.use_stft_context = use_stft_context
911911

912+
self.context_features = context_features
912913
context_channels_pad_length = num_layers + 1 - len(context_channels)
913914
context_channels = context_channels + [0] * context_channels_pad_length
914915
self.context_channels = context_channels
916+
self.context_embedding_features = context_embedding_features
915917

916918
if use_context_channels:
917919
has_context = [c > 0 for c in context_channels]
@@ -1140,22 +1142,21 @@ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
11401142
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
11411143

11421144

1143-
class UNetConditional1d(UNet1d):
1144-
"""
1145-
UNet1d with classifier-free guidance on the token embeddings
1146-
"""
1145+
class UNetCFG1d(UNet1d):
1146+
1147+
"""UNet1d with Classifier-Free Guidance"""
11471148

11481149
def __init__(
11491150
self,
1150-
context_embedding_features: int,
11511151
context_embedding_max_length: int,
1152+
context_embedding_features: int,
11521153
**kwargs,
11531154
):
11541155
super().__init__(
11551156
context_embedding_features=context_embedding_features, **kwargs
11561157
)
11571158
self.fixed_embedding = FixedEmbedding(
1158-
context_embedding_max_length, context_embedding_features
1159+
max_length=context_embedding_max_length, features=context_embedding_features
11591160
)
11601161

11611162
def forward( # type: ignore
@@ -1178,14 +1179,72 @@ def forward( # type: ignore
11781179
)
11791180
embedding = torch.where(batch_mask, fixed_embedding, embedding)
11801181

1181-
out = super().forward(x, time, embedding=embedding, **kwargs)
1182-
11831182
if embedding_scale != 1.0:
1184-
# Scale conditional output using classifier-free guidance
1183+
# Compute both normal and fixed embedding outputs
1184+
out = super().forward(x, time, embedding=embedding, **kwargs)
11851185
out_masked = super().forward(x, time, embedding=fixed_embedding, **kwargs)
1186-
out = out_masked + (out - out_masked) * embedding_scale
1186+
# Scale conditional output using classifier-free guidance
1187+
return out_masked + (out - out_masked) * embedding_scale
1188+
else:
1189+
return super().forward(x, time, embedding=embedding, **kwargs)
1190+
1191+
1192+
class UNetNCCA1d(UNet1d):

    """UNet1d with Noise Channel Conditioning Augmentation

    Each tensor in `channels_list` is associated with a noise scale. The
    scales are embedded with a `NumberEmbedder`, summed over the items, and
    handed to `UNet1d.forward` through its `features` argument.
    """

    def __init__(self, context_features: int, **kwargs):
        """
        Args:
            context_features: Size of the feature vector; forwarded to UNet1d
                and used as the `features` size of the scale embedder.
            **kwargs: Remaining arguments, forwarded unchanged to UNet1d.
        """
        super().__init__(context_features=context_features, **kwargs)
        self.embedder = NumberEmbedder(features=context_features)

    def forward(  # type: ignore
        self,
        x: Tensor,
        time: Tensor,
        *,
        channels_list: Sequence[Tensor],
        channels_augmentation: bool = False,
        channels_scale: Union[int, Sequence[int]] = 0,
        **kwargs,
    ) -> Tensor:
        """
        Args:
            x: Input tensor; its first dimension is used as the batch size.
            time: Forwarded unchanged to `UNet1d.forward`.
            channels_list: Conditioning tensors, one per item. This sequence
                is never modified; when augmentation is on, noised copies are
                forwarded instead.
                NOTE(review): the "b -> b 1 1" broadcast assumes each item is
                3-dimensional with batch first - confirm against UNet1d.
            channels_augmentation: If True, draw one uniform random scale in
                [0, 1) per item and batch element and blend each item with
                Gaussian noise by that scale. `channels_scale` is ignored.
            channels_scale: Used only when augmentation is off. Either one
                scale applied to every item, or a sequence with one scale per
                item. The items themselves are not noised in this case; only
                the scale embedding is computed from these values.
            **kwargs: Forwarded unchanged to `UNet1d.forward`.

        Returns:
            The output of `UNet1d.forward`.

        Raises:
            AssertionError: If `channels_scale` is a sequence whose length
                differs from `len(channels_list)`.
        """
        b, num_items = x.shape[0], len(channels_list)

        if channels_augmentation:
            # Random noise augmentation for each item
            scales = torch.rand(num_items, b).to(x)
            # Build a new list rather than assigning into `channels_list`:
            # writing in place would leak the noised tensors back to the
            # caller (so repeated calls with the same list stack noise on
            # noise) and would fail outright for immutable sequences.
            noised = []
            for item, item_scale in zip(channels_list, scales):
                scale = rearrange(item_scale, "b -> b 1 1")
                noised.append(torch.randn_like(item) * scale + item * (1 - scale))
            channels_list = noised
        else:
            # Expand same scale to each batch element
            if is_sequence(channels_scale):
                assert_message = "len(channels_scale) must match len(channels_list)"
                assert len(channels_scale) == num_items, assert_message  # type: ignore # noqa
            else:
                channels_scale = num_items * [channels_scale]  # type: ignore
            scales = torch.tensor(channels_scale).to(x)
            scales = repeat(scales, "n -> n b", b=b)

        # Compute scale feature embedding, summed over the items
        scale_embedding = self.embedder(scales)
        scale_embedding = reduce(scale_embedding, "n b d -> b d", "sum")

        return super().forward(
            x=x,
            time=time,
            channels_list=channels_list,
            features=scale_embedding,
            **kwargs,
        )
1240+
1241+
1242+
class UNetAll1d(UNetCFG1d, UNetNCCA1d):
1243+
def __init__(self, *args, **kwargs):
1244+
super().__init__(*args, **kwargs)
11871245

1188-
return out
1246+
def forward(self, *args, **kwargs): # type: ignore
1247+
return UNetCFG1d.forward(self, *args, **kwargs)
11891248

11901249

11911250
class T5Embedder(nn.Module):

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.93",
6+
version="0.0.94",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)