@@ -810,6 +810,7 @@ def __init__(
810810 use_skip_scale : bool ,
811811 use_context_time : bool ,
812812 norm : float = 0.0 ,
813+ norm_alpha : float = 20.0 ,
813814 out_channels : Optional [int ] = None ,
814815 context_features : Optional [int ] = None ,
815816 context_channels : Optional [Sequence [int ]] = None ,
@@ -823,8 +824,9 @@ def __init__(
823824 use_context_channels = len (context_channels ) > 0
824825 context_mapping_features = None
825826
826- self .norm = norm
827827 self .use_norm = norm > 0.0
828+ self .norm = norm
829+ self .norm_alpha = norm_alpha
828830 self .num_layers = num_layers
829831 self .use_context_time = use_context_time
830832 self .use_context_features = use_context_features
@@ -1003,7 +1005,7 @@ def forward(
10031005 mapping = self .get_mapping (time , features )
10041006
10051007 if self .use_norm :
1006- x = wave_norm (x , peak = self .norm )
1008+ x = wave_norm (x , peak = self .norm , alpha = self . norm_alpha )
10071009
10081010 x = self .to_in (x , mapping )
10091011 skips_list = [x ]
@@ -1025,7 +1027,7 @@ def forward(
10251027 x = self .to_out (x , mapping )
10261028
10271029 if self .use_norm :
1028- x = wave_unnorm (x , peak = self .norm )
1030+ x = wave_unnorm (x , peak = self .norm , alpha = self . norm_alpha )
10291031
10301032 return x
10311033
@@ -1129,13 +1131,15 @@ def __init__(
11291131 use_noisy : bool = False ,
11301132 bottleneck : Optional [Bottleneck ] = None ,
11311133 norm : float = 0.0 ,
1134+ norm_alpha : float = 20.0 ,
11321135 ):
11331136 super ().__init__ ()
11341137 num_layers = len (multipliers ) - 1
11351138 self .bottleneck = bottleneck
11361139 self .use_noisy = use_noisy
11371140 self .use_norm = norm > 0.0
11381141 self .norm = norm
1142+ self .norm_alpha = norm_alpha
11391143
11401144 assert len (factors ) >= num_layers and len (num_blocks ) >= num_layers
11411145
@@ -1186,7 +1190,7 @@ def encode(
11861190 self , x : Tensor , with_info : bool = False
11871191 ) -> Union [Tensor , Tuple [Tensor , Any ]]:
11881192 if self .use_norm :
1189- x = wave_norm (x , peak = self .norm )
1193+ x = wave_norm (x , peak = self .norm , alpha = self . norm_alpha )
11901194
11911195 x = self .to_in (x )
11921196 for downsample in self .downsamples :
@@ -1207,7 +1211,7 @@ def decode(self, x: Tensor) -> Tensor:
12071211 x = self .to_out (x )
12081212
12091213 if self .use_norm :
1210- x = wave_unnorm (x , peak = self .norm )
1214+ x = wave_unnorm (x , peak = self .norm , alpha = self . norm_alpha )
12111215
12121216 return x
12131217
0 commit comments