11import math
22from math import pi
3- from typing import Any , List , Optional , Sequence , Tuple , Union
3+ from typing import Any , Dict , List , Optional , Sequence , Tuple , Union
44
55import torch
66import torch .nn as nn
99from einops_exts import rearrange_many
1010from torch import Tensor , einsum
1111
12- from .utils import default , exists , prod
12+ from .utils import default , exists , prod , to_list
1313
1414"""
1515Utils
@@ -1341,6 +1341,15 @@ def forward(
13411341 return (out , dict (loss = loss , mean = mean , logvar = logvar )) if with_info else out
13421342
13431343
class Tanh(Bottleneck):
    """Bottleneck that squashes the latent into (-1, 1) with tanh.

    Stateless: contributes no loss terms, so its info dict is always empty.

    Args:
        x: latent tensor to squash.
        with_info: when True, also return an (empty) info dict, matching the
            ``Bottleneck`` interface used by ``AutoEncoder1d.encode``.

    Returns:
        The squashed tensor, or ``(tensor, {})`` when ``with_info`` is set.
    """

    def forward(
        self, x: Tensor, with_info: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Any]]:
        squashed = x.tanh()
        if with_info:
            return squashed, {}
        return squashed
1351+
1352+
13441353class AutoEncoder1d (nn .Module ):
13451354 def __init__ (
13461355 self ,
@@ -1353,12 +1362,12 @@ def __init__(
13531362 factors : Sequence [int ],
13541363 num_blocks : Sequence [int ],
13551364 use_noisy : bool = False ,
1356- bottleneck : Optional [Bottleneck ] = None ,
1365+ bottleneck : Union [Bottleneck , List [ Bottleneck ]] = [] ,
13571366 use_magnitude_channels : bool = False ,
13581367 ):
13591368 super ().__init__ ()
13601369 num_layers = len (multipliers ) - 1
1361- self .bottleneck = bottleneck
1370+ self .bottlenecks = to_list ( bottleneck )
13621371 self .use_noisy = use_noisy
13631372 self .use_magnitude_channels = use_magnitude_channels
13641373
@@ -1424,8 +1433,8 @@ def encode(
14241433 xs += [x ]
14251434 info = dict (xs = xs )
14261435
1427- if exists ( self .bottleneck ) :
1428- x , info_bottleneck = self . bottleneck (x , with_info = True )
1436+ for bottleneck in self .bottlenecks :
1437+ x , info_bottleneck = bottleneck (x , with_info = True )
14291438 info = {** info , ** info_bottleneck }
14301439
14311440 return (x , info ) if with_info else x
0 commit comments