1212 Sampler ,
1313 Schedule ,
1414)
15- from .modules import UNet1d
15+ from .modules import AutoEncoder1d , UNet1d
16+ from .utils import exists
1617
1718
1819class Model1d (nn .Module ):
@@ -39,9 +40,19 @@ def __init__(
3940 diffusion_sigma_data : int ,
4041 diffusion_dynamic_threshold : float ,
4142 out_channels : Optional [int ] = None ,
43+ use_autoencoder : bool = False ,
44+ autoencoder : Optional [AutoEncoder1d ] = None ,
45+ autoencoder_scale : float = 1.0 ,
4246 ):
4347 super ().__init__ ()
4448
49+ self .use_autoencoder = use_autoencoder
50+
51+ if use_autoencoder :
52+ assert exists (autoencoder )
53+ self .autoencoder_scale = autoencoder_scale
54+ self .autoencoder = autoencoder
55+
4556 self .unet = UNet1d (
4657 in_channels = in_channels ,
4758 channels = channels ,
@@ -71,6 +82,8 @@ def __init__(
7182 )
7283
def forward(self, x: Tensor) -> Tensor:
    """Pass ``x`` through the diffusion module, via the autoencoder if enabled.

    When ``self.use_autoencoder`` is set, ``x`` is first encoded with
    ``self.autoencoder.encode`` and the result is multiplied by
    ``self.autoencoder_scale``; otherwise ``x`` is forwarded unchanged.

    Returns:
        Whatever ``self.diffusion`` returns for the (possibly encoded) input
        (presumably the training loss; confirm against the diffusion module).
    """
    # Plain path: no autoencoder configured, diffuse the raw input.
    if not self.use_autoencoder:
        return self.diffusion(x)

    # The autoencoder attribute is Optional at the type level, hence the ignore.
    latent = self.autoencoder.encode(x)  # type: ignore
    return self.diffusion(self.autoencoder_scale * latent)
7588
7689 def sample (
@@ -82,19 +95,35 @@ def sample(
8295 sigma_schedule = sigma_schedule ,
8396 num_steps = num_steps ,
8497 )
85- return diffusion_sampler (noise )
98+ x = diffusion_sampler (noise )
8699
100+ if self .use_autoencoder :
101+ x = (1.0 / self .autoencoder_scale ) * self .autoencoder .decode (x )
87102
88- class AudioDiffusionModel (Model1d ):
103+ return x
104+
105+
class AudioAutoEncoderModel(AutoEncoder1d):
    """``AutoEncoder1d`` preconfigured with a fixed set of default arguments.

    Every default below can be overridden by passing the same keyword to the
    constructor; positional arguments are forwarded to ``AutoEncoder1d`` as-is.
    """

    def __init__(self, *args, **kwargs):
        # Start from the preset configuration...
        settings = {
            "in_channels": 1,
            "bottleneck_channels": 128,
            "channels": 128,
            "patch_size": 16,
            "multipliers": [1, 1, 1, 1, 1],
            "factors": [1, 4, 4, 4],
            "num_blocks": [2, 2, 2, 2],
            "resnet_groups": 8,
            "kernel_multiplier_downsample": 2,
            "loss_kl_weight": 1e-8,
        }
        # ...then let caller-supplied keywords take precedence.
        settings.update(kwargs)
        super().__init__(*args, **settings)
121+
122+
123+ class AudioDiffusionModel (Model1d ):
124+ def __init__ (self , * args , ** kwargs ):
125+ default_kwargs = dict (
126+ channels = 128 ,
98127 attention_heads = 8 ,
99128 attention_features = 64 ,
100129 attention_multiplier = 2 ,
@@ -106,14 +135,41 @@ def __init__(self, *args, **kwargs):
106135 use_attention_bottleneck = True ,
107136 use_learned_time_embedding = True ,
108137 diffusion_sigma_distribution = LogNormalDistribution (mean = - 3.0 , std = 1.0 ),
109- diffusion_sigma_data = 0.1 ,
110- diffusion_dynamic_threshold = 0.95 ,
111138 )
112- super ().__init__ (* args , ** {** default_kwargs , ** kwargs })
139+
140+ model_kwargs = None
141+
142+ if "autoencoder" in kwargs :
143+ sigma_data = 0.2
144+ model_kwargs = dict (
145+ in_channels = 128 ,
146+ patch_size = 1 ,
147+ multipliers = [1 , 4 , 4 , 4 ],
148+ factors = [2 , 2 , 2 ],
149+ num_blocks = [2 , 2 , 2 ],
150+ attentions = [True , True , True ],
151+ diffusion_sigma_data = sigma_data ,
152+ diffusion_dynamic_threshold = 0.0 ,
153+ use_autoencoder = True ,
154+ autoencoder_scale = sigma_data ,
155+ )
156+ else :
157+ model_kwargs = dict (
158+ in_channels = 1 ,
159+ patch_size = 16 ,
160+ multipliers = [1 , 2 , 4 , 4 , 4 , 4 , 4 ],
161+ factors = [4 , 4 , 4 , 2 , 2 , 2 ],
162+ num_blocks = [2 , 2 , 2 , 2 , 2 , 2 ],
163+ attentions = [False , False , False , True , True , True ],
164+ diffusion_sigma_data = 0.1 ,
165+ diffusion_dynamic_threshold = 0.95 ,
166+ use_autoencoder = False ,
167+ )
168+ super ().__init__ (* args , ** {** default_kwargs , ** model_kwargs , ** kwargs })
113169
def sample(self, *args, **kwargs):
    """Call the parent ``sample`` with a default sigma schedule and sampler.

    ``sigma_schedule`` defaults to
    ``KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)`` and
    ``sampler`` to ``ADPM2Sampler(rho=1.0)``. Keywords supplied by the caller
    override these defaults; positional arguments are forwarded unchanged.

    Returns:
        The result of the parent class's ``sample``.
    """
    # Defaults are built eagerly, exactly as before, even if later overridden.
    sampling_kwargs = {
        "sigma_schedule": KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
        "sampler": ADPM2Sampler(rho=1.0),
    }
    sampling_kwargs.update(kwargs)
    return super().sample(*args, **sampling_kwargs)
0 commit comments