@@ -1687,3 +1687,99 @@ def decode(self, latent: Tensor) -> List[Tensor]:
16871687 x = self .to_out (x )
16881688 channels_list += [x ]
16891689 return channels_list [::- 1 ]
1690+
1691+
class STFT(nn.Module):
    """Short-time Fourier transform pair for batched, multi-channel waveforms.

    `encode` turns a waveform into magnitude and phase spectrograms, and
    `decode` rebuilds a waveform from such a pair. A Hann window of
    `window_length` samples is registered as a module buffer and used by
    both directions.

    Args:
        length: Number of samples requested from `torch.istft` in `decode`.
        num_fft: FFT size, forwarded as `n_fft`.
        hop_length: Hop between successive frames, in samples.
        window_length: Length of the Hann window, in samples.
    """

    def __init__(
        self,
        length: int,
        num_fft: int = 1024,
        hop_length: int = 256,
        window_length: int = 1024,
    ):
        super().__init__()
        self.num_fft = num_fft
        self.hop_length = hop_length
        self.window_length = window_length
        self.length = length
        self.register_buffer("window", torch.hann_window(window_length))

    def encode(self, wave: Tensor) -> Tuple[Tensor, Tensor]:
        """Return the magnitude and phase spectrograms of `wave`.

        Args:
            wave: Waveform laid out as `(batch, channels, time)`.

        Returns:
            A `(magnitude, phase)` pair, each laid out as
            `(batch, channels, frequency, frames)`.
        """
        b = wave.shape[0]
        # Fold channels into the batch axis so every channel is transformed
        # independently by a single torch.stft call.
        wave = rearrange(wave, "b c t -> (b c) t")

        stft = torch.stft(
            wave,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            return_complex=True,
        )

        # The squared magnitude is floored at 1e-8 before the square root, so
        # the returned magnitude is never smaller than 1e-4.
        mag = torch.sqrt(torch.clamp((stft.real ** 2) + (stft.imag ** 2), min=1e-8))
        mag = rearrange(mag, "(b c) f l -> b c f l", b=b)

        phase = torch.angle(stft)
        phase = rearrange(phase, "(b c) f l -> b c f l", b=b)
        return mag, phase

    def decode(self, magnitude: Tensor, phase: Tensor) -> Tensor:
        """Rebuild a waveform from magnitude and phase spectrograms.

        Args:
            magnitude: Magnitudes laid out as
                `(batch, channels, frequency, frames)`.
            phase: Phase angles in radians, same shape as `magnitude`.

        Returns:
            Waveform laid out as `(batch, channels, time)`, where `time` is
            the `length` given at construction.

        Raises:
            AssertionError: If `magnitude` and `phase` differ in shape.
        """
        b = magnitude.shape[0]
        assert magnitude.shape == phase.shape, "magnitude and phase must be same shape"
        real = rearrange(magnitude * torch.cos(phase), "b c f l -> (b c) f l")
        imag = rearrange(magnitude * torch.sin(phase), "b c f l -> (b c) f l")
        # torch.istft requires a complex-valued spectrogram. Stacking the two
        # parts on a trailing axis of size 2 is the deprecated real-view
        # layout, which recent PyTorch releases reject.
        stft = torch.complex(real, imag)

        wave = torch.istft(
            stft,
            n_fft=self.num_fft,
            hop_length=self.hop_length,
            win_length=self.window_length,
            window=self.window,  # type: ignore
            length=self.length,
        )
        wave = rearrange(wave, "(b c) t -> b c t", b=b)
        return wave
1744+
1745+
class STFTAutoEncoder1d(AutoEncoder1d):
    """Autoencoder that works on STFT log-magnitudes instead of raw samples.

    `encode` converts the waveform to a log-magnitude spectrogram, flattens
    channels and frequency bins into one channel axis and hands the result to
    the parent encoder. `decode` runs the parent decoder, which is configured
    to emit two values (log-magnitude and phase) per input channel and
    frequency bin, and inverts the STFT to obtain a waveform.

    Args:
        in_channels: Number of channels of the input waveform.
        length: Number of samples of the waveform produced by `decode`.
        num_fft: FFT size used by the STFT.
        hop_length: Hop between successive STFT frames, in samples.
        window_length: Length of the STFT window, in samples.
        **kwargs: Forwarded unchanged to the parent constructor.
    """

    def __init__(
        self,
        in_channels: int,
        length: int,
        num_fft: int = 1024,
        hop_length: int = 256,
        window_length: int = 1024,
        **kwargs,
    ):
        # Number of frequency bins in the one-sided spectrum of a real signal.
        self.frequency_channels = num_fft // 2 + 1

        super().__init__(
            in_channels=in_channels * self.frequency_channels,
            # Two outputs per (channel, bin): log-magnitude and phase.
            out_channels=in_channels * self.frequency_channels * 2,
            patch_blocks=1,
            patch_factor=1,
            **kwargs,
        )

        self.stft = STFT(
            num_fft=num_fft,
            hop_length=hop_length,
            window_length=window_length,
            length=length,
        )

    def encode(
        self, wave: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        """Encode a waveform through its log-magnitude spectrogram.

        The phase returned by the STFT is discarded; only the log-magnitude
        is passed to the parent encoder.

        Args:
            wave: Waveform laid out as `(batch, channels, time)`.
            with_info: Forwarded to the parent `encode`.

        Returns:
            Whatever the parent `encode` returns for the flattened
            log-magnitude input.
        """
        magnitude, phase = self.stft.encode(wave)
        log_magnitude = rearrange(torch.log(magnitude), "b c f t -> b (c f) t")
        return super().encode(log_magnitude, with_info)

    def decode(self, z: Tensor) -> Tensor:
        """Decode a latent into a waveform.

        Args:
            z: Latent accepted by the parent `decode`.

        Returns:
            Waveform produced by the inverse STFT of the decoded
            log-magnitude and phase.
        """
        f = self.frequency_channels
        stft = super().decode(z)
        # Move the magnitude/phase index `i` to its own leading axis before
        # splitting. Grouping it as "(c i)" and halving that axis would split
        # on the channel index whenever there is more than one channel,
        # mixing magnitude and phase values in both halves.
        stft = rearrange(stft, "b (c f i) t -> i b c f t", i=2, f=f)
        log_magnitude, phase = stft.unbind(dim=0)
        return self.stft.decode(magnitude=torch.exp(log_magnitude), phase=phase)
0 commit comments