1919)
2020from .modules import STFT , SinusoidalEmbedding , UNet1d , UNetConditional1d
2121from .utils import (
22- closest_power_2 ,
2322 default ,
2423 downsample ,
2524 exists ,
@@ -204,7 +203,7 @@ def encode(
204203 ) -> Union [Tensor , Tuple [Tensor , Any ]]:
205204 latent , info = self .encoder (x , with_info = True )
206205 for bottleneck in self .bottlenecks :
207- x , info_bottleneck = bottleneck (x , with_info = True )
206+ latent , info_bottleneck = bottleneck (latent , with_info = True )
208207 info = {** info , ** prefix_dict ("bottleneck_" , info_bottleneck )}
209208 return (latent , info ) if with_info else latent
210209
@@ -226,31 +225,43 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:
226225
227226
228227class DiffusionVocoder1d (Model1d ):
229- def __init__ (self , in_channels : int , stft_num_fft : int , ** kwargs ):
230- self .stft_num_fft = stft_num_fft
231- spectrogram_channels = stft_num_fft // 2 + 1
228+ def __init__ (
229+ self ,
230+ in_channels : int ,
231+ stft_num_fft : int ,
232+ ** kwargs ,
233+ ):
234+ self .frequency_channels = stft_num_fft // 2 + 1
235+ spectrogram_channels = in_channels * self .frequency_channels
236+
237+ stft_kwargs , kwargs = groupby ("stft_" , kwargs )
232238 default_kwargs = dict (
233- in_channels = in_channels ,
234- use_stft = True ,
235- stft_num_fft = stft_num_fft ,
236- context_channels = [in_channels * spectrogram_channels ],
239+ in_channels = spectrogram_channels , context_channels = [spectrogram_channels ]
237240 )
241+
238242 super ().__init__ (** {** default_kwargs , ** kwargs }) # type: ignore
243+ self .stft = STFT (num_fft = stft_num_fft , ** stft_kwargs )
239244
240- def forward (self , x : Tensor , ** kwargs ) -> Tensor :
241- # Get magnitude spectrogram from true wave
242- magnitude , _ = self .unet .stft .encode (x )
243- magnitude = rearrange (magnitude , "b c f t -> b (c f) t" )
244- # Get diffusion loss while conditioning on magnitude
245- return self .diffusion (x , channels_list = [magnitude ], ** kwargs )
245+ def forward_wave (self , x : Tensor , ** kwargs ) -> Tensor :
246+ # Get magnitude and phase of true wave
247+ magnitude , phase = self .stft .encode (x )
248+ return self (magnitude , phase , ** kwargs )
246249
247- def sample (self , spectrogram : Tensor , ** kwargs ): # type: ignore
248- b , c , _ , t , device = * spectrogram .shape , spectrogram .device
249- magnitude = rearrange (spectrogram , "b c f t -> b (c f) t" )
250- timesteps = closest_power_2 (self .unet .stft .hop_length * t )
251- noise = torch .randn ((b , c , timesteps ), device = device )
252- default_kwargs = dict (channels_list = [magnitude ])
253- return super ().sample (noise , ** {** default_kwargs , ** kwargs }) # type: ignore # noqa
250+ def forward (self , magnitude : Tensor , phase : Tensor , ** kwargs ) -> Tensor : # type: ignore # noqa
251+ magnitude = rearrange (magnitude , "b c f t -> b (c f) t" )
252+ phase = rearrange (phase , "b c f t -> b (c f) t" )
253+ # Get diffusion phase loss while conditioning on magnitude (/pi [-1,1] range)
254+ return self .diffusion (phase / pi , channels_list = [magnitude ], ** kwargs )
255+
256+ def sample (self , magnitude : Tensor , ** kwargs ): # type: ignore
257+ b , c , f , t , device = * magnitude .shape , magnitude .device
258+ magnitude_flat = rearrange (magnitude , "b c f t -> b (c f) t" )
259+ noise = torch .randn ((b , c * f , t ), device = device )
260+ default_kwargs = dict (channels_list = [magnitude_flat ])
261+ phase_flat = super ().sample (noise , ** {** default_kwargs , ** kwargs }) # type: ignore # noqa
262+ phase = rearrange (phase_flat , "b (c f) t -> b c f t" , c = c )
263+ wave = self .stft .decode (magnitude , phase * pi )
264+ return wave
254265
255266
256267class DiffusionUpphaser1d (DiffusionUpsampler1d ):
@@ -371,13 +382,13 @@ def __init__(self, in_channels: int, **kwargs):
371382 in_channels = in_channels ,
372383 stft_num_fft = 1023 ,
373384 stft_hop_length = 256 ,
374- channels = 64 ,
385+ channels = 512 ,
375386 patch_blocks = 1 ,
376387 patch_factor = 1 ,
377- multipliers = [48 , 32 , 16 , 8 , 8 , 8 , 8 ],
378- factors = [2 , 2 , 2 , 1 , 1 , 1 ],
379- num_blocks = [1 , 1 , 1 , 1 , 1 , 1 ],
380- attentions = [0 , 0 , 0 , 1 , 1 , 1 ],
388+ multipliers = [3 , 2 , 1 , 1 , 1 , 1 , 1 , 1 ],
389+ factors = [1 , 2 , 2 , 2 , 2 , 2 , 2 ],
390+ num_blocks = [1 , 1 , 1 , 1 , 1 , 1 , 1 ],
391+ attentions = [0 , 0 , 0 , 0 , 1 , 1 , 1 ],
381392 attention_heads = 8 ,
382393 attention_features = 64 ,
383394 attention_multiplier = 2 ,
0 commit comments