@@ -1277,19 +1277,23 @@ def forward(self, texts: List[str]) -> Tensor:
12771277
12781278
12791279class STFT (nn .Module ):
1280+ """Helper for torch stft and istft"""
1281+
12801282 def __init__ (
12811283 self ,
12821284 num_fft : int = 1023 ,
1283- hop_length : Optional [ int ] = None ,
1285+ hop_length : int = 256 ,
12841286 window_length : Optional [int ] = None ,
12851287 length : Optional [int ] = None ,
1288+ use_complex : bool = False ,
12861289 ):
12871290 super ().__init__ ()
12881291 self .num_fft = num_fft
12891292 self .hop_length = default (hop_length , floor (num_fft // 4 ))
12901293 self .window_length = default (window_length , num_fft )
12911294 self .length = length
12921295 self .register_buffer ("window" , torch .hann_window (self .window_length ))
1296+ self .use_complex = use_complex
12931297
12941298 def encode (self , wave : Tensor ) -> Tuple [Tensor , Tensor ]:
12951299 b = wave .shape [0 ]
@@ -1302,43 +1306,54 @@ def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
13021306 win_length = self .window_length ,
13031307 window = self .window , # type: ignore
13041308 return_complex = True ,
1309+ normalized = True ,
13051310 )
13061311
1307- mag = torch .sqrt (torch .clamp ((stft .real ** 2 ) + (stft .imag ** 2 ), min = 1e-8 ))
1308- mag = rearrange (mag , "(b c) f l -> b c f l" , b = b )
1312+ if self .use_complex :
1313+ # Returns real and imaginary
1314+ stft_a , stft_b = stft .real , stft .imag
1315+ else :
1316+ # Returns magnitude and phase matrices
1317+ magnitude , phase = torch .abs (stft ), torch .angle (stft )
1318+ stft_a , stft_b = magnitude , phase
13091319
1310- phase = torch .angle (stft )
1311- phase = rearrange (phase , "(b c) f l -> b c f l" , b = b )
1312- return mag , phase
1320+ return rearrange_many ((stft_a , stft_b ), "(b c) f l -> b c f l" , b = b )
13131321
1314- def decode (self , magnitude : Tensor , phase : Tensor ) -> Tensor :
1315- b , l = magnitude .shape [0 ], magnitude .shape [- 1 ] # noqa
1316- assert magnitude .shape == phase .shape , "magnitude and phase must be same shape"
1317- real = rearrange (magnitude * torch .cos (phase ), "b c f l -> (b c) f l" )
1318- imag = rearrange (magnitude * torch .sin (phase ), "b c f l -> (b c) f l" )
1319- stft = torch .stack ([real , imag ], dim = - 1 )
1322+ def decode (self , stft_a : Tensor , stft_b : Tensor ) -> Tensor :
1323+ b , l = stft_a .shape [0 ], stft_a .shape [- 1 ] # noqa
13201324 length = closest_power_2 (l * self .hop_length )
13211325
1326+ stft_a , stft_b = rearrange_many ((stft_a , stft_b ), "b c f l -> (b c) f l" )
1327+
1328+ if self .use_complex :
1329+ real , imag = stft_a , stft_b
1330+ else :
1331+ magnitude , phase = stft_a , stft_b
1332+ real , imag = magnitude * torch .cos (phase ), magnitude * torch .sin (phase )
1333+
1334+ stft = torch .stack ([real , imag ], dim = - 1 )
1335+
13221336 wave = torch .istft (
13231337 stft ,
13241338 n_fft = self .num_fft ,
13251339 hop_length = self .hop_length ,
13261340 win_length = self .window_length ,
13271341 window = self .window , # type: ignore
13281342 length = default (self .length , length ),
1343+ normalized = True ,
13291344 )
1330- wave = rearrange ( wave , "(b c) t -> b c t" , b = b )
1331- return wave
1345+
1346+ return rearrange ( wave , "(b c) t -> b c t" , b = b )
13321347
13331348 def encode1d (
13341349 self , wave : Tensor , stacked : bool = True
13351350 ) -> Union [Tensor , Tuple [Tensor , Tensor ]]:
1336- magnitude , phase = self .encode (wave )
1337- magnitude , phase = rearrange_many ((magnitude , phase ), "b c f l -> b (c f) l" )
1338- return torch .cat ((magnitude , phase ), dim = 1 ) if stacked else (magnitude , phase )
1351+ stft_a , stft_b = self .encode (wave )
1352+ stft_a , stft_b = rearrange_many ((stft_a , stft_b ), "b c f l -> b (c f) l" )
1353+ return torch .cat ((stft_a , stft_b ), dim = 1 ) if stacked else (stft_a , stft_b )
13391354
1340- def decode1d (self , magnitude_and_phase : Tensor ) -> Tensor :
1355+ def decode1d (self , stft_pair : Tensor ) -> Tensor :
13411356 f = self .num_fft // 2 + 1
1342- magnitude , phase = magnitude_and_phase .chunk (chunks = 2 , dim = 1 )
1343- mag , phase = rearrange_many ((magnitude , phase ), "b (c f) l -> b c f l" , f = f )
1344- return self .decode (mag , phase )
1357+ stft_a , stft_b = stft_pair .chunk (chunks = 2 , dim = 1 )
1358+ stft_a , stft_b = rearrange_many ((stft_a , stft_b ), "b (c f) l -> b c f l" , f = f )
1359+ return self .decode (stft_a , stft_b )
0 commit comments