
Commit 13cddd6

feat: update to sigmoid norm
1 parent 749fb20 commit 13cddd6

File tree

audio_diffusion_pytorch/modules.py
audio_diffusion_pytorch/utils.py
setup.py

3 files changed (+18, -13 lines)

audio_diffusion_pytorch/modules.py

Lines changed: 9 additions & 5 deletions
@@ -810,6 +810,7 @@ def __init__(
         use_skip_scale: bool,
         use_context_time: bool,
         norm: float = 0.0,
+        norm_alpha: float = 20.0,
         out_channels: Optional[int] = None,
         context_features: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
@@ -823,8 +824,9 @@ def __init__(
         use_context_channels = len(context_channels) > 0
         context_mapping_features = None
 
-        self.norm = norm
         self.use_norm = norm > 0.0
+        self.norm = norm
+        self.norm_alpha = norm_alpha
         self.num_layers = num_layers
         self.use_context_time = use_context_time
         self.use_context_features = use_context_features
@@ -1003,7 +1005,7 @@ def forward(
         mapping = self.get_mapping(time, features)
 
         if self.use_norm:
-            x = wave_norm(x, peak=self.norm)
+            x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha)
 
         x = self.to_in(x, mapping)
         skips_list = [x]
@@ -1025,7 +1027,7 @@ def forward(
         x = self.to_out(x, mapping)
 
         if self.use_norm:
-            x = wave_unnorm(x, peak=self.norm)
+            x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha)
 
         return x
 
@@ -1129,13 +1131,15 @@ def __init__(
         use_noisy: bool = False,
         bottleneck: Optional[Bottleneck] = None,
         norm: float = 0.0,
+        norm_alpha: float = 20.0,
     ):
         super().__init__()
         num_layers = len(multipliers) - 1
         self.bottleneck = bottleneck
         self.use_noisy = use_noisy
         self.use_norm = norm > 0.0
         self.norm = norm
+        self.norm_alpha = norm_alpha
 
         assert len(factors) >= num_layers and len(num_blocks) >= num_layers
 
@@ -1186,7 +1190,7 @@ def encode(
         self, x: Tensor, with_info: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
         if self.use_norm:
-            x = wave_norm(x, peak=self.norm)
+            x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha)
 
         x = self.to_in(x)
         for downsample in self.downsamples:
@@ -1207,7 +1211,7 @@ def decode(self, x: Tensor) -> Tensor:
         x = self.to_out(x)
 
         if self.use_norm:
-            x = wave_unnorm(x, peak=self.norm)
+            x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha)
 
         return x
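In both classes touched in modules.py, normalization stays gated by norm > 0.0, and the new norm_alpha is now forwarded to wave_norm / wave_unnorm (updated in utils.py below), so the sigmoid steepness is configurable per model. A minimal standalone sketch of the wrapping pattern this adds; the NormWrapped class and its inner body module are illustrative, not part of the repo:

from torch import Tensor, nn

from audio_diffusion_pytorch.utils import wave_norm, wave_unnorm


class NormWrapped(nn.Module):
    """Illustrative wrapper: squash the waveform before the body, undo it after."""

    def __init__(self, body: nn.Module, norm: float = 0.0, norm_alpha: float = 20.0):
        super().__init__()
        self.body = body
        self.use_norm = norm > 0.0  # same gating as in modules.py
        self.norm = norm
        self.norm_alpha = norm_alpha

    def forward(self, x: Tensor) -> Tensor:
        if self.use_norm:
            x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha)
        x = self.body(x)
        if self.use_norm:
            x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha)
        return x

With norm=0.0 (the default) both normalization steps are skipped entirely, matching the existing use_norm gating.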

audio_diffusion_pytorch/utils.py

Lines changed: 8 additions & 7 deletions
@@ -85,15 +85,16 @@ def upsample(waveforms: Tensor, factor: int, **kwargs) -> Tensor:
     return resample(waveforms, factor_in=1, factor_out=factor, **kwargs)
 
 
-def wave_norm(x: Tensor, bits: int = 24, peak: float = 0.5) -> Tensor:
-    mu = 2 ** bits
+def wave_norm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor:
+    x = x.clip(-1, 1)
+    x = torch.sigmoid(alpha * x)
     x = x.clip(-1, 1)
-    x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / math.log1p(mu)
     return x * peak
 
 
-def wave_unnorm(x: Tensor, bits: int = 24, peak: float = 0.5) -> Tensor:
-    x = (x / peak).clip(-1, 1)
-    mu = 2 ** bits
-    x = torch.sign(x) * (torch.exp(torch.abs(x) * math.log1p(mu)) - 1) / mu
+def wave_unnorm(x: Tensor, peak: float = 0.5, alpha: float = 20.0) -> Tensor:
+    x = x / peak
+    x = x.clip(-1, 1)
+    x = (1.0 / alpha) * torch.log(x / (1 - x))
+    x = x.clip(-1, 1)
     return x
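The sigmoid pair replaces the previous mu-law-style companding: wave_norm clips to [-1, 1], applies sigmoid(alpha * x), and scales by peak, while wave_unnorm divides by peak and applies the scaled logit (1 / alpha) * log(x / (1 - x)), the exact inverse of the sigmoid. A quick round-trip sketch; the values and tolerance are illustrative, and float64 is used because a steep float32 sigmoid saturates to exactly 1.0 as x approaches 1:

import torch

from audio_diffusion_pytorch.utils import wave_norm, wave_unnorm

# Waveform samples in [-1, 1]; float64 keeps the sigmoid away from saturation.
x = torch.linspace(-0.9, 0.9, steps=9, dtype=torch.float64)

y = wave_norm(x, peak=0.5, alpha=20.0)        # squashed into (0, peak)
x_rec = wave_unnorm(y, peak=0.5, alpha=20.0)  # scaled logit undoes the sigmoid

print(torch.allclose(x, x_rec, atol=1e-6))    # True: the pair round-trips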

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.54",
+    version="0.0.55",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
