Commit aaaa699

feat: add magnitude channels option, change to quantile norm
1 parent: 330c60d

4 files changed (+33, -39 lines)

audio_diffusion_pytorch/model.py

Lines changed: 1 addition & 0 deletions
@@ -220,6 +220,7 @@ def get_default_model_kwargs():
         use_nearest_upsample=False,
         use_skip_scale=True,
         use_context_time=True,
+        use_magnitude_channels=False,
         diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
         diffusion_sigma_data=0.1,
         diffusion_dynamic_threshold=0.0,
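Since `get_default_model_kwargs()` now defaults `use_magnitude_channels` to `False`, the feature is opt-in. A minimal sketch of enabling it, assuming the package's top-level `AudioDiffusionModel` wrapper forwards keyword overrides on top of these defaults (the wrapper name and the kwarg forwarding are assumptions, not shown in this diff):

# Sketch under assumptions: AudioDiffusionModel merges user kwargs over
# get_default_model_kwargs() and passes them through to the UNet.
import torch
from audio_diffusion_pytorch import AudioDiffusionModel

model = AudioDiffusionModel(use_magnitude_channels=True)  # new flag, default False
x = torch.randn(1, 1, 2 ** 18)  # [batch, in_channels, length]
loss = model(x)                 # diffusion training loss
loss.backward()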

audio_diffusion_pytorch/modules.py

Lines changed: 31 additions & 23 deletions
@@ -8,7 +8,7 @@
 from einops_exts import rearrange_many
 from torch import Tensor, einsum
 
-from .utils import default, exists, prod, wave_norm, wave_unnorm
+from .utils import default, exists, prod
 
 """
 Utils
@@ -785,6 +785,15 @@ def forward(
         return x
 
 
+def get_norm_scale(x: Tensor, quantile: float):
+    return torch.quantile(x.abs(), quantile, dim=-1, keepdim=True) + 1e-7
+
+
+def merge_magnitude_channels(x: Tensor):
+    waveform, magnitude = torch.chunk(x, chunks=2, dim=1)
+    return torch.sigmoid(waveform) * torch.tanh(magnitude)
+
+
 """
 UNet
 """
@@ -809,8 +818,8 @@ def __init__(
         use_nearest_upsample: bool,
         use_skip_scale: bool,
         use_context_time: bool,
-        norm: float = 0.0,
-        norm_alpha: float = 20.0,
+        use_magnitude_channels: bool,
+        norm_quantile: float = 0.0,
         out_channels: Optional[int] = None,
         context_features: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
@@ -824,9 +833,6 @@ def __init__(
         use_context_channels = len(context_channels) > 0
         context_mapping_features = None
 
-        self.use_norm = norm > 0.0
-        self.norm = norm
-        self.norm_alpha = norm_alpha
         self.num_layers = num_layers
         self.use_context_time = use_context_time
         self.use_context_features = use_context_features
@@ -841,6 +847,10 @@ def __init__(
         self.has_context = has_context
         self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
 
+        self.use_norm = norm_quantile > 0.0
+        self.norm_quantile = norm_quantile
+        self.use_magnitude_channels = use_magnitude_channels
+
         assert (
             len(factors) == num_layers
             and len(attentions) >= num_layers
@@ -943,7 +953,7 @@ def __init__(
 
         self.to_out = Unpatcher(
             in_channels=channels,
-            out_channels=out_channels,
+            out_channels=out_channels * (2 if use_magnitude_channels else 1),
             blocks=patch_blocks,
             factor=patch_factor,
             context_mapping_features=context_mapping_features,
@@ -1002,10 +1012,11 @@ def forward(
         # Concat context channels at layer 0 if provided
         channels = self.get_channels(channels_list, layer=0)
         x = torch.cat([x, channels], dim=1) if exists(channels) else x
+        # Compute mapping from time and features
         mapping = self.get_mapping(time, features)
-
-        if self.use_norm:
-            x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha)
+        # Compute norm scale
+        scale = get_norm_scale(x, self.norm_quantile) if self.use_norm else 1.0
+        x = x / scale
 
         x = self.to_in(x, mapping)
         skips_list = [x]
@@ -1026,10 +1037,10 @@ def forward(
             x += skips_list.pop()
         x = self.to_out(x, mapping)
 
-        if self.use_norm:
-            x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha)
+        if self.use_magnitude_channels:
+            x = merge_magnitude_channels(x)
 
-        return x
+        return x * scale
 
 
 class FixedEmbedding(nn.Module):
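Putting the two `forward` hunks together: the input is rescaled by the quantile norm on the way in, and the scale is reapplied on the way out, with the optional magnitude-channel merge in between. A rough sketch of that flow, using the two helpers from the earlier hunk and a stand-in `net` for the patcher/UNet blocks (not the actual class):

def forward_sketch(x, net, norm_quantile=0.95, use_magnitude_channels=True):
    # Per-sample, per-channel scale from the |x| quantile
    # (in the real forward, scale is 1.0 when norm_quantile == 0.0).
    scale = get_norm_scale(x, norm_quantile)
    x = x / scale
    # The network predicts 2x the output channels when magnitude channels are enabled.
    x = net(x)
    # Collapse the (waveform, magnitude) channel pairs into a bounded waveform.
    if use_magnitude_channels:
        x = merge_magnitude_channels(x)
    # Restore the original amplitude.
    return x * scale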
@@ -1130,16 +1141,13 @@ def __init__(
         num_blocks: Sequence[int],
         use_noisy: bool = False,
         bottleneck: Optional[Bottleneck] = None,
-        norm: float = 0.0,
-        norm_alpha: float = 20.0,
+        use_magnitude_channels: bool = False,
     ):
         super().__init__()
         num_layers = len(multipliers) - 1
         self.bottleneck = bottleneck
         self.use_noisy = use_noisy
-        self.use_norm = norm > 0.0
-        self.norm = norm
-        self.norm_alpha = norm_alpha
+        self.use_magnitude_channels = use_magnitude_channels
 
         assert len(factors) >= num_layers and len(num_blocks) >= num_layers
 
@@ -1181,16 +1189,14 @@ def __init__(
 
         self.to_out = Unpatcher(
             in_channels=channels * (use_noisy + 1),
-            out_channels=in_channels,
+            out_channels=in_channels * (2 if use_magnitude_channels else 1),
             blocks=patch_blocks,
             factor=patch_factor,
         )
 
     def encode(
         self, x: Tensor, with_info: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        if self.use_norm:
-            x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha)
 
         x = self.to_in(x)
         for downsample in self.downsamples:
@@ -1206,12 +1212,14 @@ def decode(self, x: Tensor) -> Tensor:
             if self.use_noisy:
                 x = torch.cat([x, torch.randn_like(x)], dim=1)
             x = upsample(x)
+
         if self.use_noisy:
             x = torch.cat([x, torch.randn_like(x)], dim=1)
+
         x = self.to_out(x)
 
-        if self.use_norm:
-            x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha)
+        if self.use_magnitude_channels:
+            x = merge_magnitude_channels(x)
 
         return x

audio_diffusion_pytorch/utils.py

Lines changed: 0 additions & 15 deletions
@@ -83,18 +83,3 @@ def downsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
 
 def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
     return resample(waveforms, factor_in=1, factor_out=factor, **kwargs)
-
-
-def wave_norm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor:
-    x = x.clip(-1, 1)
-    x = 2 * torch.sigmoid(alpha * x) - 1
-    x = x.clip(-1, 1)
-    return x * peak
-
-
-def wave_unnorm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor:
-    x = x / peak
-    x = x.clip(-1, 1)
-    x = (1.0 / alpha) * torch.log((x + 1) / (1 - x))
-    x = x.clip(-1, 1)
-    return x
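For context on why these were dropped: `wave_norm` companded the waveform through a steep sigmoid, which clips anything outside [-1, 1] and loses precision where the sigmoid saturates, and `wave_unnorm` had to invert that companding. The quantile scale that replaces it is a plain division, so dividing and multiplying back recovers the input. A small comparison sketch (the 0.95 quantile and tensor sizes are illustrative):

import torch

x = torch.randn(1, 2, 1024) * 0.3

# Old path (removed above): sigmoid companding, clips |x| > 1 and degrades
# near saturation when inverted.
peak, alpha = 0.5, 20.0
y = (2 * torch.sigmoid(alpha * x.clip(-1, 1)) - 1).clip(-1, 1) * peak
x_back = ((1.0 / alpha) * torch.log((y / peak + 1) / (1 - y / peak))).clip(-1, 1)

# New path: a linear per-channel scale, so the round trip is a divide and multiply.
scale = torch.quantile(x.abs(), 0.95, dim=-1, keepdim=True) + 1e-7
assert torch.allclose((x / scale) * scale, x)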

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.56",
+    version="0.0.57",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
