88from einops_exts import rearrange_many
99from torch import Tensor , einsum
1010
11- from .utils import default , exists , prod
11+ from .utils import default , exists , prod , wave_norm , wave_unnorm
1212
1313"""
1414Utils
@@ -809,6 +809,7 @@ def __init__(
809809 use_nearest_upsample : bool ,
810810 use_skip_scale : bool ,
811811 use_context_time : bool ,
812+ norm : float = 0.0 ,
812813 out_channels : Optional [int ] = None ,
813814 context_features : Optional [int ] = None ,
814815 context_channels : Optional [Sequence [int ]] = None ,
@@ -822,6 +823,8 @@ def __init__(
822823 use_context_channels = len (context_channels ) > 0
823824 context_mapping_features = None
824825
826+ self .norm = norm
827+ self .use_norm = norm > 0.0
825828 self .num_layers = num_layers
826829 self .use_context_time = use_context_time
827830 self .use_context_features = use_context_features
@@ -997,9 +1000,11 @@ def forward(
9971000 # Concat context channels at layer 0 if provided
9981001 channels = self .get_channels (channels_list , layer = 0 )
9991002 x = torch .cat ([x , channels ], dim = 1 ) if exists (channels ) else x
1000-
10011003 mapping = self .get_mapping (time , features )
10021004
1005+ if self .use_norm :
1006+ x = wave_norm (x , peak = self .norm )
1007+
10031008 x = self .to_in (x , mapping )
10041009 skips_list = [x ]
10051010
@@ -1019,6 +1024,9 @@ def forward(
10191024 x += skips_list .pop ()
10201025 x = self .to_out (x , mapping )
10211026
1027+ if self .use_norm :
1028+ x = wave_unnorm (x , peak = self .norm )
1029+
10221030 return x
10231031
10241032
@@ -1120,11 +1128,14 @@ def __init__(
11201128 num_blocks : Sequence [int ],
11211129 use_noisy : bool = False ,
11221130 bottleneck : Optional [Bottleneck ] = None ,
1131+ norm : float = 0.0 ,
11231132 ):
11241133 super ().__init__ ()
11251134 num_layers = len (multipliers ) - 1
11261135 self .bottleneck = bottleneck
11271136 self .use_noisy = use_noisy
1137+ self .use_norm = norm > 0.0
1138+ self .norm = norm
11281139
11291140 assert len (factors ) >= num_layers and len (num_blocks ) >= num_layers
11301141
@@ -1174,6 +1185,9 @@ def __init__(
11741185 def encode (
11751186 self , x : Tensor , with_info : bool = False
11761187 ) -> Union [Tensor , Tuple [Tensor , Any ]]:
1188+ if self .use_norm :
1189+ x = wave_norm (x , peak = self .norm )
1190+
11771191 x = self .to_in (x )
11781192 for downsample in self .downsamples :
11791193 x = downsample (x )
@@ -1190,7 +1204,12 @@ def decode(self, x: Tensor) -> Tensor:
11901204 x = upsample (x )
11911205 if self .use_noisy :
11921206 x = torch .cat ([x , torch .randn_like (x )], dim = 1 )
1193- return self .to_out (x )
1207+ x = self .to_out (x )
1208+
1209+ if self .use_norm :
1210+ x = wave_unnorm (x , peak = self .norm )
1211+
1212+ return x
11941213
11951214
11961215class MultiEncoder1d (nn .Module ):
0 commit comments