1+ from math import prod
12from typing import Optional , Sequence
23
34import torch
1314 Sampler ,
1415 Schedule ,
1516)
16- from .modules import UNet1d
17+ from .modules import Encoder1d , ResnetBlock1d , UNet1d
18+
19+ """ Diffusion Classes (generic for 1d data) """
1720
1821
1922class Model1d (nn .Module ):
@@ -47,8 +50,6 @@ def __init__(
4750 in_channels = in_channels ,
4851 channels = channels ,
4952 patch_size = patch_size ,
50- resnet_groups = resnet_groups ,
51- kernel_multiplier_downsample = kernel_multiplier_downsample ,
5253 kernel_sizes_init = kernel_sizes_init ,
5354 multipliers = multipliers ,
5455 factors = factors ,
@@ -57,9 +58,11 @@ def __init__(
5758 attention_heads = attention_heads ,
5859 attention_features = attention_features ,
5960 attention_multiplier = attention_multiplier ,
61+ use_attention_bottleneck = use_attention_bottleneck ,
62+ resnet_groups = resnet_groups ,
63+ kernel_multiplier_downsample = kernel_multiplier_downsample ,
6064 use_nearest_upsample = use_nearest_upsample ,
6165 use_skip_scale = use_skip_scale ,
62- use_attention_bottleneck = use_attention_bottleneck ,
6366 out_channels = out_channels ,
6467 context_channels = context_channels ,
6568 )
@@ -91,10 +94,110 @@ def sample(
9194 return diffusion_sampler (noise , ** kwargs )
9295
9396
class DiffusionUpsampler1d(Model1d):
    """Diffusion model conditioned on a low-resolution copy of its input.

    The conditioning context is built by keeping every `factor`-th sample
    along dim 2 and repeating each kept sample `factor` times, and is passed
    to the diffusion model as a single-item context list.
    """

    def __init__(self, factor: int, in_channels: int, *args, **kwargs):
        """Configure the base model with one context of `in_channels` channels.

        Args:
            factor: Stride used to downsample along dim 2, and the number of
                times each kept sample is repeated to build the context.
            in_channels: Forwarded to the base model; also used as the
                default channel count of the single context.
            *args: Forwarded positionally to the base class constructor.
            **kwargs: Forwarded to the base class constructor; entries here
                override the defaults set below.
        """
        self.factor = factor
        default_kwargs = dict(
            in_channels=in_channels,
            context_channels=[in_channels],
        )
        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Run `self.diffusion` on `x`, conditioned on its degraded copy.

        Args:
            x: Input tensor, indexed here with three dimensions; dim 2 is the
                one that is downsampled.
            **kwargs: Forwarded to `self.diffusion`.

        Returns:
            Whatever `self.diffusion` returns.
        """
        # Downsample by picking every `factor` item
        downsampled = x[:, :, :: self.factor]
        # Upsample by interleaving to get context.
        # NOTE: when x.shape[2] is not a multiple of `factor`, the context is
        # longer than x (ceil(len / factor) * factor items).
        context = torch.repeat_interleave(downsampled, repeats=self.factor, dim=2)
        return self.diffusion(x, context=[context], **kwargs)

    def sample(self, undersampled: Tensor, *args, **kwargs):  # type: ignore
        """Sample an upsampled signal conditioned on `undersampled`.

        Args:
            undersampled: Low-resolution tensor; each item along dim 2 is
                repeated `factor` times to build the context.
            *args: Extra positional arguments forwarded to the base class
                `sample` after the generated noise.
            **kwargs: Forwarded to the base class `sample`; a `context` entry
                here overrides the one built from `undersampled`.

        Returns:
            Whatever the base class `sample` returns.
        """
        # Upsample context by interleaving
        context = torch.repeat_interleave(undersampled, repeats=self.factor, dim=2)
        noise = torch.randn_like(context)
        default_kwargs = dict(context=[context])
        # Forward *args too: they are accepted by this signature, so dropping
        # them would silently discard caller-supplied positional arguments.
        return super().sample(noise, *args, **{**default_kwargs, **kwargs})  # type: ignore
119+
120+
class DiffusionAutoencoder1d(Model1d):
    """Diffusion model conditioned on a latent produced by its own encoder.

    `encode` maps an input to a tanh-bounded latent; `forward` and `decode`
    project that latent with `to_context` and pass the result as the single
    context of the diffusion model.
    """

    def __init__(
        self,
        in_channels: int,
        channels: int,
        patch_size: int,
        kernel_sizes_init: Sequence[int],
        multipliers: Sequence[int],
        factors: Sequence[int],
        num_blocks: Sequence[int],
        resnet_groups: int,
        kernel_multiplier_downsample: int,
        encoder_depth: int,
        encoder_channels: int,
        context_channels: int,
        **kwargs
    ):
        """Build the base model, the encoder and the latent-to-context block.

        The first nine parameters are passed unchanged to both the base
        class constructor and `Encoder1d`.

        Args:
            encoder_depth: Number of leading entries of `factors` folded into
                `encoder_factor`; also sets the position of the non-zero
                entry in the context / extract channel lists.
            encoder_channels: Channel count requested from the encoder and
                consumed by `to_context`.
            context_channels: Channel count produced by `to_context` and
                declared to the base model as its only non-zero context.
            **kwargs: Forwarded to the base class constructor only.
        """
        # Arguments that the base model and the encoder have in common.
        shared_kwargs = dict(
            in_channels=in_channels,
            channels=channels,
            patch_size=patch_size,
            kernel_sizes_init=kernel_sizes_init,
            multipliers=multipliers,
            factors=factors,
            num_blocks=num_blocks,
            resnet_groups=resnet_groups,
            kernel_multiplier_downsample=kernel_multiplier_downsample,
        )

        super().__init__(
            context_channels=[0] * encoder_depth + [context_channels],
            **shared_kwargs,
            **kwargs
        )

        self.in_channels = in_channels
        # Ratio between output length and latent length, used by `decode`.
        self.encoder_factor = patch_size * prod(factors[:encoder_depth])

        self.encoder = Encoder1d(
            extract_channels=[0] * (encoder_depth - 1) + [encoder_channels],
            **shared_kwargs,
        )

        self.to_context = ResnetBlock1d(
            in_channels=encoder_channels,
            out_channels=context_channels,
            num_groups=resnet_groups,
        )

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Encode `x` and run `self.diffusion` on `x` with the latent context.

        Extra keyword arguments are forwarded to `self.diffusion`.
        """
        context = self.to_context(self.encode(x))
        return self.diffusion(x, context=[context], **kwargs)

    def encode(self, x: Tensor) -> Tensor:
        """Return tanh of the last item produced by the encoder for `x`."""
        return torch.tanh(self.encoder(x)[-1])

    def decode(self, latent: Tensor, **kwargs) -> Tensor:
        """Sample an output conditioned on `latent`.

        Args:
            latent: Tensor whose dim 0 gives the batch size and whose dim 2,
                multiplied by `encoder_factor`, gives the noise length.
            **kwargs: Forwarded to the base class `sample`; a `context` entry
                here overrides the one computed from `latent`.

        Returns:
            Whatever the base class `sample` returns.
        """
        batch_size = latent.shape[0]
        # Infer the output length from the latent length
        length = latent.shape[2] * self.encoder_factor
        # Draw noise, then match it to the latent via `.to(latent)`
        noise = torch.randn(batch_size, self.in_channels, length).to(latent)
        # Compute context from latent
        context = self.to_context(latent)
        # Decode by sampling while conditioning on latent context
        sampling_kwargs = {"context": [context], **kwargs}
        return super().sample(noise, **sampling_kwargs)  # type: ignore
193+
194+
195+ """ Audio Diffusion Classes (specific for 1d audio data) """
196+
197+
94198class AudioDiffusionModel (Model1d ):
95199 def __init__ (self , * args , ** kwargs ):
96200 default_kwargs = dict (
97- in_channels = 1 ,
98201 channels = 128 ,
99202 patch_size = 16 ,
100203 kernel_sizes_init = [1 , 3 , 7 ],
@@ -125,10 +228,8 @@ def sample(self, *args, **kwargs):
125228 return super ().sample (* args , ** {** default_kwargs , ** kwargs })
126229
127230
128- class AudioDiffusionUpsampler (Model1d ):
129- def __init__ (self , factor : int , in_channels : int = 1 , * args , ** kwargs ):
130- self .factor = factor
131-
231+ class AudioDiffusionUpsampler (DiffusionUpsampler1d ):
232+ def __init__ (self , in_channels : int , * args , ** kwargs ):
132233 default_kwargs = dict (
133234 in_channels = in_channels ,
134235 channels = 128 ,
@@ -154,19 +255,45 @@ def __init__(self, factor: int, in_channels: int = 1, *args, **kwargs):
154255
155256 super ().__init__ (* args , ** {** default_kwargs , ** kwargs }) # type: ignore
156257
157- def forward (self , x : Tensor , ** kwargs ) -> Tensor :
158- # Downsample by picking every `factor` item
159- downsampled = x [:, :, :: self . factor ]
160- # Upsample by interleaving to get context
161- context = torch . repeat_interleave ( downsampled , repeats = self . factor , dim = 2 )
162- return self . diffusion ( x , context = [ context ] , ** kwargs )
def sample(self, *args, **kwargs):
    """Call the parent `sample` with a default schedule and sampler.

    Defaults are a `KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)`
    as `sigma_schedule` and an `ADPM2Sampler(rho=1.0)` as `sampler`; any
    keyword argument supplied by the caller takes precedence. Positional
    arguments are forwarded unchanged.
    """
    sampling_kwargs = {
        "sigma_schedule": KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
        "sampler": ADPM2Sampler(rho=1.0),
        **kwargs,
    }
    return super().sample(*args, **sampling_kwargs)
163264
164- def sample (self , start : Tensor , * args , ** kwargs ): # type: ignore
165- context = torch .repeat_interleave (start , repeats = self .factor , dim = 2 )
166- noise = torch .randn_like (context )
265+
class AudioDiffusionAutoencoder(DiffusionAutoencoder1d):
    """`DiffusionAutoencoder1d` with preset hyperparameters.

    Every preset can be overridden through keyword arguments. `in_channels`
    has no preset and must be supplied by the caller.
    """

    def __init__(self, *args, **kwargs):
        """Merge caller keyword arguments over the presets and build the model."""
        presets = {
            "channels": 128,
            "patch_size": 16,
            "kernel_sizes_init": [1, 3, 7],
            "multipliers": [1, 2, 4, 4, 4, 4, 4],
            "factors": [4, 4, 4, 2, 2, 2],
            "num_blocks": [2, 2, 2, 2, 2, 2],
            "attentions": [False, False, False, True, True, True],
            "attention_heads": 8,
            "attention_features": 64,
            "attention_multiplier": 2,
            "use_attention_bottleneck": True,
            "resnet_groups": 8,
            "kernel_multiplier_downsample": 2,
            "use_nearest_upsample": False,
            "use_skip_scale": True,
            "encoder_depth": 4,
            "encoder_channels": 32,
            "context_channels": 512,
            "diffusion_sigma_distribution": LogNormalDistribution(mean=-3.0, std=1.0),
            "diffusion_sigma_data": 0.1,
            "diffusion_dynamic_threshold": 0.0,
        }
        # Caller-supplied values win over the presets.
        presets.update(kwargs)
        super().__init__(*args, **presets)  # type: ignore

    def decode(self, *args, **kwargs):
        """Call the parent `decode` with a default schedule and sampler.

        Defaults are a `KarrasSchedule(sigma_min=0.0001, sigma_max=3.0,
        rho=9.0)` as `sigma_schedule` and an `ADPM2Sampler(rho=1.0)` as
        `sampler`; caller keyword arguments take precedence.
        """
        decoding_kwargs = {
            "sigma_schedule": KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            "sampler": ADPM2Sampler(rho=1.0),
            **kwargs,
        }
        return super().decode(*args, **decoding_kwargs)
0 commit comments