Skip to content

Commit d730653

Browse files
feat: add xunet, refactor all models with xunet
1 parent 5bf1837 commit d730653

File tree

5 files changed

+31
-23
lines changed

5 files changed

+31
-23
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ from audio_diffusion_pytorch import UNet1d
150150
unet = UNet1d(
151151
in_channels=1,
152152
channels=128,
153-
patch_factor=16,
154-
patch_blocks=1,
153+
patch_size=16,
155154
multipliers=[1, 2, 4, 4, 4, 4, 4],
156155
factors=[4, 4, 4, 2, 2, 2],
157156
attentions=[0, 0, 0, 1, 1, 1, 1],

audio_diffusion_pytorch/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@
3535
DiffusionVocoder1d,
3636
Model1d,
3737
)
38-
from .modules import NumberEmbedder, T5Embedder, UNet1d
38+
from .modules import NumberEmbedder, T5Embedder, UNet1d, XUNet1d

audio_diffusion_pytorch/model.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tqdm import tqdm
1010

1111
from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion
12-
from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetCFG1d, rand_bool
12+
from .modules import STFT, SinusoidalEmbedding, XUNet1d, rand_bool
1313
from .utils import (
1414
closest_power_2,
1515
default,
@@ -28,18 +28,11 @@
2828

2929

3030
class Model1d(nn.Module):
31-
def __init__(
32-
self, diffusion_type: str, use_classifier_free_guidance: bool = False, **kwargs
33-
):
31+
def __init__(self, unet_type: str = "base", **kwargs):
3432
super().__init__()
3533
diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
36-
37-
UNet = UNetCFG1d if use_classifier_free_guidance else UNet1d
38-
self.unet = UNet(**kwargs)
39-
40-
self.diffusion = XDiffusion(
41-
type=diffusion_type, net=self.unet, **diffusion_kwargs
42-
)
34+
self.unet = XUNet1d(type=unet_type, **kwargs)
35+
self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)
4336

4437
def forward(self, x: Tensor, **kwargs) -> Tensor:
4538
return self.diffusion(x, **kwargs)
@@ -119,10 +112,10 @@ def __init__(
119112
encoder_channels: int,
120113
encoder_factors: Sequence[int],
121114
encoder_multipliers: Sequence[int],
122-
diffusion_type: str,
123115
encoder_patch_size: int = 1,
124116
bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
125117
bottleneck_channels: Optional[int] = None,
118+
unet_type: str = "base",
126119
**kwargs,
127120
):
128121
super().__init__()
@@ -138,13 +131,14 @@ def __init__(
138131
else:
139132
context_channels += [encoder_channels * encoder_multipliers[-1]]
140133

141-
self.unet = UNet1d(
142-
in_channels=in_channels, context_channels=context_channels, **kwargs
134+
self.unet = XUNet1d(
135+
type=unet_type,
136+
in_channels=in_channels,
137+
context_channels=context_channels,
138+
**kwargs,
143139
)
144140

145-
self.diffusion = XDiffusion(
146-
type=diffusion_type, net=self.unet, **diffusion_kwargs
147-
)
141+
self.diffusion = XDiffusion(net=self.unet, **diffusion_kwargs)
148142

149143
self.encoder = Encoder1d(
150144
in_channels=in_channels,
@@ -207,6 +201,7 @@ def __init__(
207201
encoder_patch_size: int = 1,
208202
bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
209203
bottleneck_channels: Optional[int] = None,
204+
unet_type: str = "base",
210205
**kwargs,
211206
):
212207
super().__init__()
@@ -233,7 +228,8 @@ def __init__(
233228
use_complex=False, # Magnitude encoding
234229
)
235230

236-
self.unet = UNet1d(
231+
self.unet = XUNet1d(
232+
type=unet_type,
237233
in_channels=in_channels,
238234
context_channels=context_channels,
239235
use_stft=True,
@@ -546,9 +542,9 @@ def __init__(
546542
self.embedding_mask_proba = embedding_mask_proba
547543
default_kwargs = dict(
548544
**get_default_model_kwargs(),
545+
unet_type="cfg",
549546
context_embedding_features=embedding_features,
550547
context_embedding_max_length=embedding_max_length,
551-
use_classifier_free_guidance=True,
552548
)
553549
super().__init__(**{**default_kwargs, **kwargs})
554550

audio_diffusion_pytorch/modules.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,19 @@ def forward(self, *args, **kwargs): # type: ignore
12471247
return UNetCFG1d.forward(self, *args, **kwargs)
12481248

12491249

1250+
def XUNet1d(type: str = "base", **kwargs) -> UNet1d:
1251+
if type == "base":
1252+
return UNet1d(**kwargs)
1253+
elif type == "all":
1254+
return UNetAll1d(**kwargs)
1255+
elif type == "cfg":
1256+
return UNetCFG1d(**kwargs)
1257+
elif type == "ncca":
1258+
return UNetNCCA1d(**kwargs)
1259+
else:
1260+
raise ValueError(f"Unknown XUNet1d type: {type}")
1261+
1262+
12501263
class T5Embedder(nn.Module):
12511264
def __init__(self, model: str = "t5-base", max_length: int = 64):
12521265
super().__init__()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name="audio-diffusion-pytorch",
55
packages=find_packages(exclude=[]),
6-
version="0.0.94",
6+
version="0.0.95",
77
license="MIT",
88
description="Audio Diffusion - PyTorch",
99
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)