@@ -2,6 +2,7 @@
 from typing import Any, Optional, Sequence, Tuple, Union
 
 import torch
+from audio_encoders_pytorch import Bottleneck, Encoder1d
 from einops import rearrange
 from torch import Tensor, nn
 
@@ -16,15 +17,18 @@
     VKDiffusion,
     VSampler,
 )
-from .modules import (
-    STFT,
-    Bottleneck,
-    MultiEncoder1d,
-    SinusoidalEmbedding,
-    UNet1d,
-    UNetConditional1d,
+from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
+from .utils import (
+    closest_power_2,
+    default,
+    downsample,
+    exists,
+    groupby,
+    prefix_dict,
+    prod,
+    to_list,
+    upsample,
 )
-from .utils import default, downsample, exists, groupby_kwargs_prefix, to_list, upsample
 
 """
 Diffusion Classes (generic for 1d data)
@@ -36,7 +40,7 @@ def __init__(
         self, diffusion_type: str, use_classifier_free_guidance: bool = False, **kwargs
     ):
         super().__init__()
-        diffusion_kwargs, kwargs = groupby_kwargs_prefix("diffusion_", kwargs)
+        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
 
         UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
         self.unet = UNet(**kwargs)
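Aside: groupby_kwargs_prefix and its renamed successor groupby both split a kwargs dict on a key prefix. The real helper lives in .utils and is not shown in this diff; a minimal sketch of the assumed behavior, inferred from the call sites:

    from typing import Any, Dict, Tuple

    def groupby(prefix: str, d: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # Assumed behavior: (prefixed entries with the prefix stripped, everything else).
        grouped = {k[len(prefix):]: v for k, v in d.items() if k.startswith(prefix)}
        rest = {k: v for k, v in d.items() if not k.startswith(prefix)}
        return grouped, rest

    # groupby("diffusion_", {"diffusion_type": "v", "channels": 128})
    # -> ({"type": "v"}, {"channels": 128})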
@@ -149,31 +153,25 @@ def __init__(
         resnet_groups: int,
         kernel_multiplier_downsample: int,
         encoder_depth: int,
-        encoder_channels: int,
-        bottleneck: Optional[Bottleneck] = None,
         encoder_num_blocks: Optional[Sequence[int]] = None,
-        encoder_out_layers: int = 0,
+        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
+        bottleneck_channels: Optional[int] = None,
+        use_stft: bool = False,
         **kwargs,
     ):
         self.in_channels = in_channels
         encoder_num_blocks = default(encoder_num_blocks, num_blocks)
         assert_message = "The number of encoder_num_blocks must be at least encoder_depth"
         assert len(encoder_num_blocks) >= encoder_depth, assert_message
+        assert patch_blocks == 1, "patch_blocks != 1 not supported"
+        assert not use_stft, "use_stft not supported"
+        self.factor = patch_factor * prod(factors[0:encoder_depth])
 
-        multiencoder = MultiEncoder1d(
-            in_channels=in_channels,
-            channels=channels,
-            patch_blocks=patch_blocks,
-            patch_factor=patch_factor,
-            num_layers=encoder_depth,
-            num_layers_out=encoder_out_layers,
-            latent_channels=encoder_channels,
-            multipliers=multipliers,
-            factors=factors,
-            num_blocks=encoder_num_blocks,
-            kernel_multiplier_downsample=kernel_multiplier_downsample,
-            resnet_groups=resnet_groups,
-        )
+        context_channels = [0] * encoder_depth
+        if exists(bottleneck_channels):
+            context_channels += [bottleneck_channels]
+        else:
+            context_channels += [channels * multipliers[encoder_depth]]
 
         super().__init__(
             in_channels=in_channels,
@@ -185,89 +183,81 @@ def __init__(
             num_blocks=num_blocks,
             resnet_groups=resnet_groups,
             kernel_multiplier_downsample=kernel_multiplier_downsample,
-            context_channels=multiencoder.channels_list,
+            context_channels=context_channels,
             **kwargs,
         )
 
-        self.bottleneck = bottleneck
-        self.multiencoder = multiencoder
+        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+        self.encoder = Encoder1d(
+            in_channels=in_channels,
+            channels=channels,
+            patch_size=patch_factor,
+            multipliers=multipliers[0 : encoder_depth + 1],
+            factors=factors[0:encoder_depth],
+            num_blocks=encoder_num_blocks[0:encoder_depth],
+            resnet_groups=resnet_groups,
+            out_channels=bottleneck_channels,
+        )
+
+    def encode(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        latent, info = self.encoder(x, with_info=True)
+        # Apply each bottleneck (e.g. a quantizer) to the latent in turn
+        for bottleneck in self.bottlenecks:
+            latent, info_bottleneck = bottleneck(latent, with_info=True)
+            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
+        return (latent, info) if with_info else latent
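The loop above assumes a small bottleneck interface: a module that maps a latent to a latent and returns an info dict when called with with_info=True. A hypothetical example mirroring the torch.tanh that the removed encode path applied inline:

    import torch
    from torch import Tensor, nn

    class TanhBottleneck(nn.Module):
        # Hypothetical bottleneck: squash latents to [-1, 1]; no extra info to report.
        def forward(self, x: Tensor, with_info: bool = False):
            x = torch.tanh(x)
            return (x, {}) if with_info else x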
 
     def forward(  # type: ignore
         self, x: Tensor, with_info: bool = False, **kwargs
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        if with_info:
-            latent, info = self.encode(x, with_info=True)
-        else:
-            latent = self.encode(x)
-
-        channels_list = self.multiencoder.decode(latent)
-        loss = self.diffusion(x, channels_list=channels_list, **kwargs)
+        latent, info = self.encode(x, with_info=True)
+        loss = self.diffusion(x, channels_list=[latent], **kwargs)
         return (loss, info) if with_info else loss
 
-    def encode(
-        self, x: Tensor, with_info: bool = False
-    ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        latent = self.multiencoder.encode(x)
-        latent = torch.tanh(latent)
-        # Apply bottleneck if provided (e.g. quantization module)
-        if exists(self.bottleneck):
-            latent, info = self.bottleneck(latent)
-            return (latent, info) if with_info else latent
-        return latent
-
     def decode(self, latent: Tensor, **kwargs) -> Tensor:
-        b, length = latent.shape[0], latent.shape[2] * self.multiencoder.factor
+        b, length = latent.shape[0], latent.shape[2] * self.factor
         # Compute noise by inferring shape from latent length
         noise = torch.randn(b, self.in_channels, length).to(latent)
         # Compute context from latent
-        channels_list = self.multiencoder.decode(latent)
-        default_kwargs = dict(channels_list=channels_list)
+        default_kwargs = dict(channels_list=[latent])
         # Decode by sampling while conditioning on latent channels
         return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
 
 
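Taken together, encode and decode now give the full autoencoding round trip: encode downsamples a wave by self.factor into a latent, and decode samples a wave of matching length conditioned on it. A usage sketch; the constructor values below are illustrative, not package defaults, and any extra UNet1d kwargs this configuration needs are omitted:

    autoencoder = DiffusionAutoencoder1d(
        in_channels=2,
        channels=32,
        patch_blocks=1,  # asserted above: only 1 is supported
        patch_factor=2,
        multipliers=[1, 2, 4, 4],
        factors=[4, 4, 4],
        num_blocks=[2, 2, 2],
        resnet_groups=8,
        kernel_multiplier_downsample=2,
        encoder_depth=3,
        bottleneck_channels=64,
        diffusion_type="v",
        diffusion_sigma_distribution=UniformDistribution(),
    )
    x = torch.randn(1, 2, 2**15)       # [batch, in_channels, length]
    latent = autoencoder.encode(x)     # length shrinks by factor = 2 * (4 * 4 * 4) = 128
    wave = autoencoder.decode(latent)  # denoises noise of length latent_length * factor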
 class DiffusionVocoder1d(Model1d):
-    def __init__(
-        self,
-        in_channels: int,
-        vocoder_num_fft: int,
-        **kwargs,
-    ):
-        self.frequency_channels = vocoder_num_fft // 2 + 1
-        spectrogram_channels = in_channels * self.frequency_channels
-
-        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+    def __init__(self, in_channels: int, stft_num_fft: int, **kwargs):
+        self.stft_num_fft = stft_num_fft
+        spectrogram_channels = stft_num_fft // 2 + 1
         default_kwargs = dict(
-            in_channels=spectrogram_channels, context_channels=[spectrogram_channels]
+            in_channels=in_channels,
+            use_stft=True,
+            stft_num_fft=stft_num_fft,
+            context_channels=[in_channels * spectrogram_channels],
         )
-
         super().__init__(**{**default_kwargs, **kwargs})  # type: ignore
-        self.stft = STFT(num_fft=vocoder_num_fft, **vocoder_kwargs)
 
     def forward(self, x: Tensor, **kwargs) -> Tensor:
-        # Get magnitude and phase of true wave
-        magnitude, phase = self.stft.encode(x)
+        # Get magnitude spectrogram from true wave
+        magnitude, _ = self.unet.stft.encode(x)
         magnitude = rearrange(magnitude, "b c f t -> b (c f) t")
-        phase = rearrange(phase, "b c f t -> b (c f) t")
-        # Get diffusion phase loss while conditioning on magnitude (/pi [-1,1] range)
-        return self.diffusion(phase / pi, channels_list=[magnitude], **kwargs)
+        # Get diffusion loss while conditioning on magnitude
+        return self.diffusion(x, channels_list=[magnitude], **kwargs)
 
     def sample(self, spectrogram: Tensor, **kwargs):  # type: ignore
-        b, c, f, t, device = *spectrogram.shape, spectrogram.device
+        b, c, _, t, device = *spectrogram.shape, spectrogram.device
         magnitude = rearrange(spectrogram, "b c f t -> b (c f) t")
-        noise = torch.randn((b, c * f, t), device=device)
+        timesteps = closest_power_2(self.unet.stft.hop_length * t)
+        noise = torch.randn((b, c, timesteps), device=device)
         default_kwargs = dict(channels_list=[magnitude])
-        phase = super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
-        phase = rearrange(phase, "b (c f) t -> b c f t", c=c)
-        wave = self.stft.decode(spectrogram, phase * pi)
-        return wave
+        return super().sample(noise, **{**default_kwargs, **kwargs})  # type: ignore # noqa
 
 
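closest_power_2 is another .utils helper not shown in this diff; sample infers the waveform length from the spectrogram frame count t and the STFT hop length. A sketch under the assumption that it rounds to the nearest power of two:

    import math

    def closest_power_2(x: float) -> int:
        # Assumed semantics: the power of two nearest to x.
        lo = 2 ** math.floor(math.log2(x))
        hi = 2 ** math.ceil(math.log2(x))
        return lo if x - lo <= hi - x else hi

    # With the stft_hop_length=256 default set below and t=345 frames:
    # closest_power_2(256 * 345) == 2**16 == 65536 noise timesteps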
 class DiffusionUpphaser1d(DiffusionUpsampler1d):
     def __init__(self, **kwargs):
-        vocoder_kwargs, kwargs = groupby_kwargs_prefix("vocoder_", kwargs)
+        stft_kwargs, kwargs = groupby("stft_", kwargs)
         super().__init__(**kwargs)
-        self.stft = STFT(**vocoder_kwargs)
+        self.stft = STFT(**stft_kwargs)
 
     def random_rephase(self, x: Tensor) -> Tensor:
         magnitude, phase = self.stft.encode(x)
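The rest of random_rephase falls outside this hunk. Given the method name and the STFT.decode(magnitude, phase) call seen in the removed vocoder code, its assumed intent is to keep the magnitude and randomize the phase, roughly:

    from math import pi

    def random_rephase(self, x: Tensor) -> Tensor:
        # Hedged sketch only; the actual body is not part of this diff.
        magnitude, phase = self.stft.encode(x)
        phase_random = (torch.rand_like(phase) - 0.5) * 2 * pi  # uniform in [-pi, pi)
        return self.stft.decode(magnitude, phase_random)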
@@ -305,7 +295,6 @@ def get_default_model_kwargs():
         use_nearest_upsample=False,
         use_skip_scale=True,
         use_context_time=True,
-        use_magnitude_channels=False,
         diffusion_type="v",
         diffusion_sigma_distribution=UniformDistribution(),
     )
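For context on how these defaults are consumed (a hypothetical illustration; the groupby semantics are assumed as sketched earlier): diffusion_type binds to Model1d's explicit parameter, the remaining diffusion_-prefixed entry is routed to the diffusion wrapper, and everything else reaches UNet1d.

    diffusion_kwargs, unet_kwargs = groupby(
        "diffusion_",
        dict(
            diffusion_sigma_distribution=UniformDistribution(),
            use_skip_scale=True,
            use_context_time=True,
        ),
    )
    # diffusion_kwargs -> {"sigma_distribution": UniformDistribution()}
    # unet_kwargs -> {"use_skip_scale": True, "use_context_time": True}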
@@ -380,12 +369,13 @@ class AudioDiffusionVocoder(DiffusionVocoder1d):
     def __init__(self, in_channels: int, **kwargs):
         default_kwargs = dict(
             in_channels=in_channels,
-            vocoder_num_fft=1023,
-            channels=32,
+            stft_num_fft=1023,
+            stft_hop_length=256,
+            channels=64,
             patch_blocks=1,
             patch_factor=1,
-            multipliers=[64, 32, 16, 8, 4, 2, 1],
-            factors=[1, 1, 1, 1, 1, 1],
+            multipliers=[48, 32, 16, 8, 8, 8, 8],
+            factors=[2, 2, 2, 1, 1, 1],
             num_blocks=[1, 1, 1, 1, 1, 1],
             attentions=[0, 0, 0, 1, 1, 1],
             attention_heads=8,
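End to end, the new vocoder denoises a waveform directly while conditioning on a magnitude spectrogram (the old version denoised phase instead). A hedged usage sketch; with stft_num_fft=1023 there are 1023 // 2 + 1 = 512 frequency bins, and num_steps is this package's usual sampler kwarg:

    vocoder = AudioDiffusionVocoder(in_channels=1)
    spectrogram = torch.randn(1, 1, 512, 345)  # [b, c, f, t]
    wave = vocoder.sample(spectrogram, num_steps=50)
    # wave.shape == (1, 1, 65536), i.e. closest_power_2(256 * 345) timesteps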