77from torch import Tensor , nn
88
99from .diffusion import LinearSchedule , UniformDistribution , VSampler , XDiffusion
10- from .modules import STFT , Conv1d , SinusoidalEmbedding , UNet1d , UNetConditional1d
10+ from .modules import STFT , SinusoidalEmbedding , UNet1d , UNetConditional1d
1111from .utils import (
1212 default ,
1313 downsample ,
@@ -153,13 +153,6 @@ def __init__(
153153 ** encoder_kwargs ,
154154 )
155155
156- if exists (bottleneck_channels ):
157- self .to_bottleneck = Conv1d (
158- in_channels = encoder_channels * encoder_multipliers [- 1 ],
159- out_channels = bottleneck_channels ,
160- kernel_size = 1 ,
161- )
162-
163156 self .encoder_downsample_factor = encoder_patch_size * prod (encoder_factors )
164157 self .bottleneck_channels = bottleneck_channels
165158 self .bottlenecks = nn .ModuleList (to_list (bottleneck ))
@@ -168,9 +161,6 @@ def encode(
168161 self , x : Tensor , with_info : bool = False
169162 ) -> Union [Tensor , Tuple [Tensor , Any ]]:
170163 latent , info = self .encoder (x , with_info = True )
171- # Convert latent channels
172- if exists (self .bottleneck_channels ):
173- latent = self .to_bottleneck (latent )
174164 # Apply bottlenecks if present
175165 for bottleneck in self .bottlenecks :
176166 latent , info_bottleneck = bottleneck (latent , with_info = True )
0 commit comments