@@ -1364,7 +1364,7 @@ def forward(
13641364 return (x , info ) if with_info else x
13651365
13661366
1367- class AutoEncoder1d (nn .Module ):
1367+ class Encoder1d (nn .Module ):
13681368 def __init__ (
13691369 self ,
13701370 in_channels : int ,
@@ -1375,16 +1375,10 @@ def __init__(
13751375 multipliers : Sequence [int ],
13761376 factors : Sequence [int ],
13771377 num_blocks : Sequence [int ],
1378- use_noisy : bool = False ,
1379- bottleneck : Union [Bottleneck , List [Bottleneck ]] = [],
1380- use_magnitude_channels : bool = False ,
1378+ out_channels : Optional [int ] = None ,
13811379 ):
13821380 super ().__init__ ()
13831381 num_layers = len (multipliers ) - 1
1384- self .bottlenecks = nn .ModuleList (to_list (bottleneck ))
1385- self .use_noisy = use_noisy
1386- self .use_magnitude_channels = use_magnitude_channels
1387-
13881382 assert len (factors ) >= num_layers and len (num_blocks ) >= num_layers
13891383
13901384 self .to_in = Patcher (
@@ -1408,10 +1402,66 @@ def __init__(
14081402 ]
14091403 )
14101404
1405+ self .to_out = (
1406+ nn .Conv1d (
1407+ in_channels = channels * multipliers [- 1 ],
1408+ out_channels = out_channels ,
1409+ kernel_size = 1 ,
1410+ )
1411+ if exists (out_channels )
1412+ else nn .Identity ()
1413+ )
1414+
1415+ def forward (
1416+ self , x : Tensor , with_info : bool = False
1417+ ) -> Union [Tensor , Tuple [Tensor , Any ]]:
1418+ xs = []
1419+ x = self .to_in (x )
1420+
1421+ for downsample in self .downsamples :
1422+ x = downsample (x )
1423+ xs += [x ]
1424+
1425+ x = self .to_out (x )
1426+
1427+ info = dict (xs = xs )
1428+ return (x , info ) if with_info else x
1429+
1430+
1431+ class Decoder1d (nn .Module ):
1432+ def __init__ (
1433+ self ,
1434+ out_channels : int ,
1435+ channels : int ,
1436+ patch_blocks : int ,
1437+ patch_factor : int ,
1438+ resnet_groups : int ,
1439+ multipliers : Sequence [int ],
1440+ factors : Sequence [int ],
1441+ num_blocks : Sequence [int ],
1442+ use_magnitude_channels : bool = False ,
1443+ in_channels : Optional [int ] = None ,
1444+ ):
1445+ super ().__init__ ()
1446+ num_layers = len (multipliers ) - 1
1447+ self .use_magnitude_channels = use_magnitude_channels
1448+
1449+ assert len (factors ) >= num_layers and len (num_blocks ) >= num_layers
1450+
1451+ self .to_in = (
1452+ Conv1d (
1453+ in_channels = in_channels ,
1454+ out_channels = channels * multipliers [- 1 ],
1455+ kernel_size = 1 ,
1456+ )
1457+ if exists (in_channels )
1458+ else nn .Identity ()
1459+ )
1460+
14111461 self .upsamples = nn .ModuleList (
14121462 [
14131463 UpsampleBlock1d (
1414- in_channels = channels * multipliers [i + 1 ] * ( use_noisy + 1 ) ,
1464+ in_channels = channels * multipliers [i + 1 ],
14151465 out_channels = channels * multipliers [i ],
14161466 factor = factors [i ],
14171467 num_groups = resnet_groups ,
@@ -1424,12 +1474,73 @@ def __init__(
14241474 )
14251475
14261476 self .to_out = Unpatcher (
1427- in_channels = channels * ( use_noisy + 1 ) ,
1428- out_channels = in_channels * (2 if use_magnitude_channels else 1 ),
1477+ in_channels = channels ,
1478+ out_channels = out_channels * (2 if use_magnitude_channels else 1 ),
14291479 blocks = patch_blocks ,
14301480 factor = patch_factor ,
14311481 )
14321482
1483+ def forward (self , x : Tensor ) -> Union [Tensor , Tuple [Tensor , Any ]]:
1484+ x = self .to_in (x )
1485+
1486+ for upsample in self .upsamples :
1487+ x = upsample (x )
1488+
1489+ x = self .to_out (x )
1490+
1491+ if self .use_magnitude_channels :
1492+ x = merge_magnitude_channels (x )
1493+
1494+ return x
1495+
1496+
1497+ class AutoEncoder1d (nn .Module ):
1498+ def __init__ (
1499+ self ,
1500+ in_channels : int ,
1501+ channels : int ,
1502+ patch_blocks : int ,
1503+ patch_factor : int ,
1504+ resnet_groups : int ,
1505+ multipliers : Sequence [int ],
1506+ factors : Sequence [int ],
1507+ num_blocks : Sequence [int ],
1508+ use_noisy : bool = False ,
1509+ bottleneck : Union [Bottleneck , List [Bottleneck ]] = [],
1510+ bottleneck_channels : Optional [int ] = None ,
1511+ use_magnitude_channels : bool = False ,
1512+ ):
1513+ super ().__init__ ()
1514+ num_layers = len (multipliers ) - 1
1515+ self .bottlenecks = nn .ModuleList (to_list (bottleneck ))
1516+
1517+ assert len (factors ) >= num_layers and len (num_blocks ) >= num_layers
1518+
1519+ self .encoder = Encoder1d (
1520+ in_channels = in_channels ,
1521+ channels = channels ,
1522+ patch_blocks = patch_blocks ,
1523+ patch_factor = patch_factor ,
1524+ resnet_groups = resnet_groups ,
1525+ multipliers = multipliers ,
1526+ factors = factors ,
1527+ num_blocks = num_blocks ,
1528+ out_channels = bottleneck_channels ,
1529+ )
1530+
1531+ self .decoder = Decoder1d (
1532+ in_channels = bottleneck_channels ,
1533+ out_channels = in_channels ,
1534+ channels = channels ,
1535+ patch_blocks = patch_blocks ,
1536+ patch_factor = patch_factor ,
1537+ resnet_groups = resnet_groups ,
1538+ multipliers = multipliers ,
1539+ factors = factors ,
1540+ num_blocks = num_blocks ,
1541+ use_magnitude_channels = use_magnitude_channels ,
1542+ )
1543+
14331544 def forward (
14341545 self , x : Tensor , with_info : bool = False
14351546 ) -> Union [Tensor , Tuple [Tensor , Any ]]:
@@ -1440,12 +1551,7 @@ def forward(
14401551 def encode (
14411552 self , x : Tensor , with_info : bool = False
14421553 ) -> Union [Tensor , Tuple [Tensor , Any ]]:
1443- xs = []
1444- x = self .to_in (x )
1445- for downsample in self .downsamples :
1446- x = downsample (x )
1447- xs += [x ]
1448- info = dict (xs = xs )
1554+ x , info = self .encoder (x , with_info = True )
14491555
14501556 for bottleneck in self .bottlenecks :
14511557 x , info_bottleneck = bottleneck (x , with_info = True )
@@ -1454,20 +1560,7 @@ def encode(
14541560 return (x , info ) if with_info else x
14551561
14561562 def decode (self , x : Tensor ) -> Tensor :
1457- for upsample in self .upsamples :
1458- if self .use_noisy :
1459- x = torch .cat ([x , torch .randn_like (x )], dim = 1 )
1460- x = upsample (x )
1461-
1462- if self .use_noisy :
1463- x = torch .cat ([x , torch .randn_like (x )], dim = 1 )
1464-
1465- x = self .to_out (x )
1466-
1467- if self .use_magnitude_channels :
1468- x = merge_magnitude_channels (x )
1469-
1470- return x
1563+ return self .decoder (x )
14711564
14721565
14731566class MultiEncoder1d (nn .Module ):
0 commit comments