     Sampler,
     Schedule,
 )
-from .modules import Encoder1d, ResnetBlock1d, UNet1d
+from .modules import Encoder1d, ResnetBlock1d, UNet1d, UNetConditional1d
 from .utils import default, exists, prod, to_list

-""" Diffusion Classes (generic for 1d data) """
+"""
+Diffusion Classes (generic for 1d data)
+"""


 class Model1d(nn.Module):
     def __init__(
         self,
-        in_channels: int,
-        channels: int,
-        patch_size: int,
-        kernel_sizes_init: Sequence[int],
-        multipliers: Sequence[int],
-        factors: Sequence[int],
-        num_blocks: Sequence[int],
-        attentions: Sequence[bool],
-        attention_heads: int,
-        attention_features: int,
-        attention_multiplier: int,
-        use_attention_bottleneck: bool,
-        resnet_groups: int,
-        kernel_multiplier_downsample: int,
-        use_nearest_upsample: bool,
-        use_skip_scale: bool,
         diffusion_sigma_distribution: Distribution,
         diffusion_sigma_data: int,
         diffusion_dynamic_threshold: float,
-        out_channels: Optional[int] = None,
-        context_channels: Optional[Sequence[int]] = None,
+        use_classifier_free_guidance: bool = False,
         **kwargs
     ):
         super().__init__()

-        self.unet = UNet1d(
-            in_channels=in_channels,
-            channels=channels,
-            patch_size=patch_size,
-            kernel_sizes_init=kernel_sizes_init,
-            multipliers=multipliers,
-            factors=factors,
-            num_blocks=num_blocks,
-            attentions=attentions,
-            attention_heads=attention_heads,
-            attention_features=attention_features,
-            attention_multiplier=attention_multiplier,
-            use_attention_bottleneck=use_attention_bottleneck,
-            resnet_groups=resnet_groups,
-            kernel_multiplier_downsample=kernel_multiplier_downsample,
-            use_nearest_upsample=use_nearest_upsample,
-            use_skip_scale=use_skip_scale,
-            out_channels=out_channels,
-            context_channels=context_channels,
-            **kwargs
-        )
+        UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
+
+        self.unet = UNet(**kwargs)

         self.diffusion = Diffusion(
             net=self.unet,
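After this hunk, `Model1d` no longer enumerates the UNet's architecture arguments: everything in `**kwargs` is forwarded verbatim to `UNet1d`, or to `UNetConditional1d` when `use_classifier_free_guidance=True`. A minimal construction sketch under the new signature (the argument values are illustrative and not a complete set; the UNet still expects its full list of architecture arguments):

    model = Model1d(
        in_channels=1,                       # forwarded to the UNet, not parsed here
        channels=128,                        # likewise forwarded
        diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
        diffusion_sigma_data=0.1,
        diffusion_dynamic_threshold=0.0,
        use_classifier_free_guidance=False,  # False -> UNet1d, True -> UNetConditional1d
    )

One trade-off of the pass-through: a misspelled architecture argument is now reported by the UNet constructor rather than by `Model1d` itself.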
@@ -114,18 +81,18 @@ def forward(self, x: Tensor, factor: Optional[int] = None, **kwargs) -> Tensor:
         # Downsample by picking every `factor` item
         downsampled = x[:, :, ::factor]
         # Upsample by interleaving to get context
-        context = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
-        return self.diffusion(x, context=[context], **kwargs)
+        channels = torch.repeat_interleave(downsampled, repeats=factor, dim=2)
+        return self.diffusion(x, channels_list=[channels], **kwargs)

     def sample(  # type: ignore
         self, undersampled: Tensor, factor: Optional[int] = None, *args, **kwargs
     ):
         # Either user provides factor or we pick the first
         factor = default(factor, self.factor[0])
-        # Upsample context by interleaving
-        context = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
-        noise = torch.randn_like(context)
-        default_kwargs = dict(context=[context])
+        # Upsample channels by interleaving
+        channels = torch.repeat_interleave(undersampled, repeats=factor, dim=2)
+        noise = torch.randn_like(channels)
+        default_kwargs = dict(channels_list=[channels])
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore

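The conditioning argument is renamed from `context` to `channels_list`, matching the UNet side, where extra channels can be injected at several depths; the upsampler passes a single entry. A shape sketch of the interleave step (tensor sizes are illustrative, `torch` imported as in this module):

    x = torch.randn(1, 2, 2**15)       # [batch, channels, length]
    downsampled = x[:, :, ::2]         # keep every 2nd sample
    channels = torch.repeat_interleave(downsampled, repeats=2, dim=2)
    assert channels.shape == x.shape   # full length again, step-interpolated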
@@ -166,7 +133,7 @@ def __init__(
             resnet_groups=resnet_groups,
             kernel_multiplier_downsample=kernel_multiplier_downsample,
             context_channels=[0] * encoder_depth + [context_channels],
-            **kwargs
+            **kwargs,
         )

         self.in_channels = in_channels
@@ -190,7 +157,7 @@ def __init__(
             extract_channels=[0] * (encoder_depth - 1) + [encoder_channels],
         )

-        self.to_context = ResnetBlock1d(
+        self.to_context_channels = ResnetBlock1d(
             in_channels=encoder_channels,
             out_channels=context_channels,
             num_groups=resnet_groups,
@@ -204,8 +171,8 @@ def forward( # type: ignore
         else:
             latent = self.encode(x)

-        context = self.to_context(latent)
-        loss = self.diffusion(x, context=[context], **kwargs)
+        channels = self.to_context_channels(latent)
+        loss = self.diffusion(x, channels_list=[channels], **kwargs)
         return (loss, info) if with_info else loss

     def encode(
@@ -224,114 +191,106 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:
         # Compute noise by inferring shape from latent length
         noise = torch.randn(b, self.in_channels, length).to(latent)
         # Compute context from latent
-        context = self.to_context(latent)
-        default_kwargs = dict(context=[context])
-        # Decode by sampling while conditioning on latent context
+        channels = self.to_context_channels(latent)
+        default_kwargs = dict(channels_list=[channels])
+        # Decode by sampling while conditioning on latent channels
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore


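Across these hunks the autoencoder's projection layer is renamed from `to_context` to `to_context_channels`, and its output is routed through `channels_list` in both `forward` and `decode`. Note how the constructor's `context_channels=[0] * encoder_depth + [context_channels]` injects those channels only at the innermost UNet depth, which is where the latent's resolution lives:

    # Illustration of the per-depth channel plan, using encoder_depth=4 and
    # context_channels=512 (the audio defaults introduced later in this diff):
    encoder_depth, context_channels = 4, 512
    plan = [0] * encoder_depth + [context_channels]
    assert plan == [0, 0, 0, 0, 512]   # no injection until the deepest level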
233- """ Audio Diffusion Classes (specific for 1d audio data) """
200+ """
201+ Audio Diffusion Classes (specific for 1d audio data)
202+ """
203+
204+
205+ def get_default_model_kwargs ():
206+ return dict (
207+ channels = 128 ,
208+ patch_size = 16 ,
209+ kernel_sizes_init = [1 , 3 , 7 ],
210+ multipliers = [1 , 2 , 4 , 4 , 4 , 4 , 4 ],
211+ factors = [4 , 4 , 4 , 2 , 2 , 2 ],
212+ num_blocks = [2 , 2 , 2 , 2 , 2 , 2 ],
213+ attentions = [False , False , False , True , True , True ],
214+ attention_heads = 8 ,
215+ attention_features = 64 ,
216+ attention_multiplier = 2 ,
217+ use_attention_bottleneck = True ,
218+ resnet_groups = 8 ,
219+ kernel_multiplier_downsample = 2 ,
220+ use_nearest_upsample = False ,
221+ use_skip_scale = True ,
222+ diffusion_sigma_distribution = LogNormalDistribution (mean = - 3.0 , std = 1.0 ),
223+ diffusion_sigma_data = 0.1 ,
224+ diffusion_dynamic_threshold = 0.0 ,
225+ )
226+
227+
228+ def get_default_sampling_kwargs ():
229+ return dict (
230+ sigma_schedule = KarrasSchedule (sigma_min = 0.0001 , sigma_max = 3.0 , rho = 9.0 ),
231+ sampler = ADPM2Sampler (rho = 1.0 ),
232+ )
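Factoring the defaults into these two helpers removes three near-identical copies below. The merge idiom `**{**default_kwargs, **kwargs}` used throughout works because later keys in a dict literal overwrite earlier ones, so caller-supplied values always win:

    defaults = dict(channels=128, patch_size=16)
    overrides = dict(channels=64)
    merged = {**defaults, **overrides}
    assert merged == dict(channels=64, patch_size=16)  # the override wins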


 class AudioDiffusionModel(Model1d):
-    def __init__(self, *args, **kwargs):
-        default_kwargs = dict(
-            channels=128,
-            patch_size=16,
-            kernel_sizes_init=[1, 3, 7],
-            multipliers=[1, 2, 4, 4, 4, 4, 4],
-            factors=[4, 4, 4, 2, 2, 2],
-            num_blocks=[2, 2, 2, 2, 2, 2],
-            attentions=[False, False, False, True, True, True],
-            attention_heads=8,
-            attention_features=64,
-            attention_multiplier=2,
-            use_attention_bottleneck=True,
-            resnet_groups=8,
-            kernel_multiplier_downsample=2,
-            use_nearest_upsample=False,
-            use_skip_scale=True,
-            diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
-            diffusion_sigma_data=0.1,
-            diffusion_dynamic_threshold=0.0,
-        )
-
-        super().__init__(*args, **{**default_kwargs, **kwargs})
+    def __init__(self, **kwargs):
+        super().__init__(**{**get_default_model_kwargs(), **kwargs})

     def sample(self, *args, **kwargs):
-        default_kwargs = dict(
-            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
-            sampler=ADPM2Sampler(rho=1.0),
-        )
-        return super().sample(*args, **{**default_kwargs, **kwargs})
+        return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
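A usage sketch for the slimmed-down wrapper (the `in_channels` value, tensor shapes, and the `num_steps` keyword on the base `sample` are illustrative assumptions):

    model = AudioDiffusionModel(in_channels=1)
    x = torch.randn(1, 1, 2**18)       # [batch, in_channels, length]
    loss = model(x)                    # one training step's loss
    loss.backward()

    noise = torch.randn(1, 1, 2**18)
    sample = model.sample(noise, num_steps=50)  # schedule and sampler come from the defaults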


 class AudioDiffusionUpsampler(DiffusionUpsampler1d):
-    def __init__(self, in_channels: int, *args, **kwargs):
+    def __init__(self, in_channels: int, **kwargs):
         default_kwargs = dict(
+            **get_default_model_kwargs(),
             in_channels=in_channels,
-            channels=128,
-            patch_size=16,
-            kernel_sizes_init=[1, 3, 7],
-            multipliers=[1, 2, 4, 4, 4, 4, 4],
-            factors=[4, 4, 4, 2, 2, 2],
-            num_blocks=[2, 2, 2, 2, 2, 2],
-            attentions=[False, False, False, True, True, True],
-            attention_heads=8,
-            attention_features=64,
-            attention_multiplier=2,
-            use_attention_bottleneck=True,
-            resnet_groups=8,
-            kernel_multiplier_downsample=2,
-            use_nearest_upsample=False,
-            use_skip_scale=True,
-            diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
-            diffusion_sigma_data=0.1,
-            diffusion_dynamic_threshold=0.0,
             context_channels=[in_channels],
         )
-
-        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
+        super().__init__(**{**default_kwargs, **kwargs})  # type: ignore

     def sample(self, *args, **kwargs):
-        default_kwargs = dict(
-            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
-            sampler=ADPM2Sampler(rho=1.0),
-        )
-        return super().sample(*args, **{**default_kwargs, **kwargs})
+        return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
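Dropping `*args` makes construction keyword-only, and `context_channels=[in_channels]` reserves a channel slot at the first UNet level for the interleaved low-rate signal. A usage sketch, assuming the base class accepts a `factor` argument (as the `self.factor[0]` lookup above suggests) and a `num_steps` keyword:

    upsampler = AudioDiffusionUpsampler(in_channels=1, factor=8)
    x = torch.randn(1, 1, 2**18)
    loss = upsampler(x)                   # keeps every 8th sample, learns to invert
    low_rate = torch.randn(1, 1, 2**15)   # length / 8
    upsampled = upsampler.sample(low_rate, num_steps=50)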


 class AudioDiffusionAutoencoder(DiffusionAutoencoder1d):
     def __init__(self, *args, **kwargs):
         default_kwargs = dict(
-            channels=128,
-            patch_size=16,
-            kernel_sizes_init=[1, 3, 7],
-            multipliers=[1, 2, 4, 4, 4, 4, 4],
-            factors=[4, 4, 4, 2, 2, 2],
-            num_blocks=[2, 2, 2, 2, 2, 2],
-            attentions=[False, False, False, True, True, True],
-            attention_heads=8,
-            attention_features=64,
-            attention_multiplier=2,
-            use_attention_bottleneck=True,
-            resnet_groups=8,
-            kernel_multiplier_downsample=2,
-            use_nearest_upsample=False,
-            use_skip_scale=True,
+            **get_default_model_kwargs(),
             encoder_depth=4,
             encoder_channels=32,
             context_channels=512,
-            diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
-            diffusion_sigma_data=0.1,
-            diffusion_dynamic_threshold=0.0,
         )
-
         super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore

     def decode(self, *args, **kwargs):
+        return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs})
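A round-trip sketch for the audio autoencoder (shapes and the `num_steps` keyword are illustrative assumptions):

    autoencoder = AudioDiffusionAutoencoder(in_channels=1)
    x = torch.randn(1, 1, 2**18)
    loss = autoencoder(x)                 # diffusion loss conditioned on the latent
    latent = autoencoder.encode(x)        # compact representation from the encoder
    x_hat = autoencoder.decode(latent, num_steps=50)  # sampler defaults applied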
+
+
+class AudioDiffusionConditional(Model1d):
+    def __init__(
+        self,
+        embedding_features: int,
+        embedding_max_length: int,
+        embedding_mask_proba: float = 0.1,
+        **kwargs
+    ):
+        self.embedding_mask_proba = embedding_mask_proba
         default_kwargs = dict(
-            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
-            sampler=ADPM2Sampler(rho=1.0),
+            **get_default_model_kwargs(),
+            context_embedding_features=embedding_features,
+            context_embedding_max_length=embedding_max_length,
+            use_classifier_free_guidance=True,
         )
-        return super().decode(*args, **{**default_kwargs, **kwargs})
+        super().__init__(**{**default_kwargs, **kwargs})
+
+    def forward(self, *args, **kwargs):
+        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
+        return super().forward(*args, **{**default_kwargs, **kwargs})
+
+    def sample(self, *args, **kwargs):
+        default_kwargs = dict(
+            **get_default_sampling_kwargs(),
+            embedding_scale=5.0,
+        )
+        return super().sample(*args, **{**default_kwargs, **kwargs})
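The new conditional wrapper enables classifier-free guidance: `embedding_mask_proba` randomly drops the conditioning during training so the model also learns an unconditional score, and `embedding_scale` sets the guidance strength at sampling time. A usage sketch (the tensor shapes and the `embedding` keyword are assumptions about the conditional UNet's interface):

    model = AudioDiffusionConditional(
        in_channels=1,
        embedding_features=512,
        embedding_max_length=64,
    )
    x = torch.randn(1, 1, 2**18)
    embedding = torch.randn(1, 64, 512)   # [batch, max_length, features]
    loss = model(x, embedding=embedding)  # conditioning masked 10% of the time

    noise = torch.randn(1, 1, 2**18)
    sample = model.sample(
        noise,
        num_steps=50,
        embedding=embedding,              # embedding_scale=5.0 applied by default
    )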