
Commit 532f4cc

feat: added standalone model, remove context stuff
1 parent: c257ad2

3 files changed: 141 additions, 136 deletions


audio_diffusion_pytorch/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -8,4 +8,5 @@
     SigmaSchedule,
     SpanBySpanComposer,
 )
-from .models import UNet1d, UNet1d_M
+from .model import AudioDiffusionModel, Model1d
+from .modules import UNet1d
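
After this change the standalone model classes are exported from the package root. A minimal sketch of the new top-level import surface (the installed package name is assumed to be audio_diffusion_pytorch, matching the directory):

from audio_diffusion_pytorch import AudioDiffusionModel, Model1d, UNet1d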

audio_diffusion_pytorch/model.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+from typing import Optional, Sequence
+
+from torch import Tensor, nn
+
+from .diffusion import (
+    Diffusion,
+    DiffusionSampler,
+    KerrasSchedule,
+    LogNormalSampler,
+    SigmaSampler,
+    SigmaSchedule,
+)
+from .modules import UNet1d
+
+
+class Model1d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        channels: int,
+        patch_size: int,
+        resnet_groups: int,
+        kernel_multiplier_downsample: int,
+        kernel_sizes_init: Sequence[int],
+        multipliers: Sequence[int],
+        factors: Sequence[int],
+        num_blocks: Sequence[int],
+        attentions: Sequence[bool],
+        attention_heads: int,
+        attention_features: int,
+        attention_multiplier: int,
+        use_learned_time_embedding: bool,
+        use_nearest_upsample: bool,
+        use_skip_scale: bool,
+        use_attention_bottleneck: bool,
+        diffusion_sigma_sampler: SigmaSampler,
+        diffusion_sigma_data: int,
+        out_channels: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.unet = UNet1d(
+            in_channels=in_channels,
+            channels=channels,
+            patch_size=patch_size,
+            resnet_groups=resnet_groups,
+            kernel_multiplier_downsample=kernel_multiplier_downsample,
+            kernel_sizes_init=kernel_sizes_init,
+            multipliers=multipliers,
+            factors=factors,
+            num_blocks=num_blocks,
+            attentions=attentions,
+            attention_heads=attention_heads,
+            attention_features=attention_features,
+            attention_multiplier=attention_multiplier,
+            use_learned_time_embedding=use_learned_time_embedding,
+            use_nearest_upsample=use_nearest_upsample,
+            use_skip_scale=use_skip_scale,
+            use_attention_bottleneck=use_attention_bottleneck,
+            out_channels=out_channels,
+        )
+
+        self.diffusion = Diffusion(
+            net=self.unet,
+            sigma_sampler=diffusion_sigma_sampler,
+            sigma_data=diffusion_sigma_data,
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.diffusion(x)
+
+    def sample(
+        self,
+        noise: Tensor,
+        num_steps: int,
+        sigma_schedule: SigmaSchedule,
+        s_tmin: float,
+        s_tmax: float,
+        s_churn: float,
+        s_noise: float,
+    ) -> Tensor:
+        sampler = DiffusionSampler(
+            diffusion=self.diffusion,
+            num_steps=num_steps,
+            sigma_schedule=sigma_schedule,
+            s_tmin=s_tmin,
+            s_tmax=s_tmax,
+            s_churn=s_churn,
+            s_noise=s_noise,
+        )
+        return sampler(noise)
+
+
+class AudioDiffusionModel(Model1d):
+    def __init__(self, *args, **kwargs):
+        default_kwargs = dict(
+            in_channels=1,
+            channels=128,
+            patch_size=16,
+            multipliers=[1, 2, 4, 4, 4, 4, 4],
+            factors=[4, 4, 4, 2, 2, 2],
+            num_blocks=[2, 2, 2, 2, 2, 2],
+            attentions=[False, False, False, True, True, True],
+            attention_heads=8,
+            attention_features=64,
+            attention_multiplier=2,
+            resnet_groups=8,
+            kernel_multiplier_downsample=2,
+            kernel_sizes_init=[1, 3, 7],
+            use_nearest_upsample=False,
+            use_skip_scale=True,
+            use_attention_bottleneck=True,
+            use_learned_time_embedding=True,
+            diffusion_sigma_sampler=LogNormalSampler(mean=-3.0, std=1.0),
+            diffusion_sigma_data=0.1,
+        )
+        super().__init__(*args, **{**default_kwargs, **kwargs})
+
+    def sample(self, *args, **kwargs):
+        default_kwargs = dict(
+            sigma_schedule=KerrasSchedule(sigma_min=0.002, sigma_max=1),
+            s_tmin=0,
+            s_tmax=10,
+            s_churn=40,
+            s_noise=1.003,
+        )
+        return super().sample(*args, **{**default_kwargs, **kwargs})
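
AudioDiffusionModel is Model1d with the audio defaults above baked in; because the constructor merges arguments as {**default_kwargs, **kwargs}, any keyword the caller passes overrides its default. A minimal usage sketch, with the caveats that forward is assumed to return a scalar diffusion training loss (what the Diffusion wrapper suggests, though this diff does not show its internals) and that the tensor shapes are illustrative:

import torch

from audio_diffusion_pytorch import AudioDiffusionModel

# Defaults come from the class above; e.g. AudioDiffusionModel(in_channels=2)
# would override in_channels while keeping every other default.
model = AudioDiffusionModel()

# Training: forward delegates to Diffusion, assumed to return a scalar loss.
x = torch.randn(2, 1, 2 ** 18)  # [batch, in_channels, length]; shape is an assumption
loss = model(x)
loss.backward()

# Sampling: only noise and num_steps are required here, since
# AudioDiffusionModel.sample fills in sigma_schedule, s_tmin, s_tmax,
# s_churn, and s_noise with the defaults above.
noise = torch.randn(2, 1, 2 ** 18)
generated = model.sample(noise=noise, num_steps=50)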
