@@ -199,6 +199,9 @@ def __init__(
199199 encoder_multipliers : Sequence [int ],
200200 diffusion_type : str ,
201201 stft_num_fft : int ,
202+ stft_hop_length : int ,
203+ stft_use_complex : bool ,
204+ stft_window_length : Optional [int ] = None ,
202205 encoder_patch_size : int = 1 ,
203206 bottleneck : Union [Bottleneck , Sequence [Bottleneck ]] = [],
204207 bottleneck_channels : Optional [int ] = None ,
@@ -209,6 +212,7 @@ def __init__(
209212
210213 encoder_kwargs , kwargs = groupby ("encoder_" , kwargs )
211214 diffusion_kwargs , kwargs = groupby ("diffusion_" , kwargs )
215+ stft_kwargs , kwargs = groupby ("stft_" , kwargs )
212216
213217 # Compute context channels
214218 context_channels = [0 ] * encoder_inject_depth
@@ -218,17 +222,26 @@ def __init__(
218222 context_channels += [encoder_channels * encoder_multipliers [- 1 ]]
219223
220224 self .spectrogram_channels = stft_num_fft // 2 + 1
225+ self .stft_hop_length = stft_hop_length
226+
227+ self .encoder_stft = STFT (
228+ num_fft = stft_num_fft ,
229+ hop_length = stft_hop_length ,
230+ window_length = stft_window_length ,
231+ use_complex = False , # Magnitude encoding
232+ )
221233
222234 self .unet = UNet1d (
223235 in_channels = in_channels ,
224- stft_num_fft = stft_num_fft ,
225236 context_channels = context_channels ,
226237 use_stft = True ,
238+ stft_use_complex = stft_use_complex ,
239+ stft_num_fft = stft_num_fft ,
240+ stft_hop_length = stft_hop_length ,
241+ stft_window_length = stft_window_length ,
227242 ** kwargs ,
228243 )
229244
230- self .stft = self .unet .stft
231-
232245 self .diffusion = XDiffusion (
233246 type = diffusion_type , net = self .unet , ** diffusion_kwargs
234247 )
@@ -251,7 +264,7 @@ def encode(
251264 self , x : Tensor , with_info : bool = False
252265 ) -> Union [Tensor , Tuple [Tensor , Any ]]:
253266 # Extract magnitude and encode
254- magnitude , _ = self .stft .encode (x )
267+ magnitude , _ = self .encoder_stft .encode (x )
255268 magnitude_flat = rearrange (magnitude , "b c f t -> b (c f) t" )
256269 latent , info = self .encoder (magnitude_flat , with_info = True )
257270 # Apply bottlenecks if present
@@ -270,7 +283,7 @@ def forward( # type: ignore
270283 def decode (self , latent : Tensor , ** kwargs ) -> Tensor :
271284 b = latent .shape [0 ]
272285 length = closest_power_2 (
273- self .stft . hop_length * latent .shape [2 ] * self .encoder_downsample_factor
286+ self .stft_hop_length * latent .shape [2 ] * self .encoder_downsample_factor
274287 )
275288 # Compute noise by inferring shape from latent length
276289 noise = torch .randn (b , self .in_channels , length , device = latent .device )
0 commit comments