 from einops_exts import rearrange_many
 from torch import Tensor, einsum
 
-from .utils import default, exists, prod, wave_norm, wave_unnorm
+from .utils import default, exists, prod
 
 """
 Utils
@@ -785,6 +785,15 @@ def forward(
         return x
 
 
+def get_norm_scale(x: Tensor, quantile: float):
+    return torch.quantile(x.abs(), quantile, dim=-1, keepdim=True) + 1e-7
+
+
+def merge_magnitude_channels(x: Tensor):
+    waveform, magnitude = torch.chunk(x, chunks=2, dim=1)
+    return torch.sigmoid(waveform) * torch.tanh(magnitude)
+
+
 """
 UNet
 """
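For reference, a minimal standalone sketch (plain PyTorch, arbitrary shapes and values not taken from the diff) of what the two helpers added in this hunk compute: `get_norm_scale` returns a per-example quantile of the absolute waveform along the time axis, and `merge_magnitude_channels` splits a doubled channel dimension into a sigmoid-gated half and a tanh-squashed half and multiplies them.

```python
import torch

x = torch.randn(2, 1, 32768)  # (batch, channels, samples) waveform

# Quantile-based scale along the time axis; the 1e-7 avoids division by zero.
scale = torch.quantile(x.abs(), 0.95, dim=-1, keepdim=True) + 1e-7
print(scale.shape)  # torch.Size([2, 1, 1])

# Magnitude-channel merge: the network emits 2x channels, chunked into a
# "waveform" half and a "magnitude" half that gate each other.
out = torch.randn(2, 2, 32768)  # 2 * out_channels
waveform, magnitude = torch.chunk(out, chunks=2, dim=1)
merged = torch.sigmoid(waveform) * torch.tanh(magnitude)
print(merged.shape)  # torch.Size([2, 1, 32768])
```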
@@ -809,8 +818,8 @@ def __init__(
         use_nearest_upsample: bool,
         use_skip_scale: bool,
         use_context_time: bool,
-        norm: float = 0.0,
-        norm_alpha: float = 20.0,
+        use_magnitude_channels: bool,
+        norm_quantile: float = 0.0,
         out_channels: Optional[int] = None,
         context_features: Optional[int] = None,
         context_channels: Optional[Sequence[int]] = None,
@@ -824,9 +833,6 @@ def __init__(
         use_context_channels = len(context_channels) > 0
         context_mapping_features = None
 
-        self.use_norm = norm > 0.0
-        self.norm = norm
-        self.norm_alpha = norm_alpha
         self.num_layers = num_layers
         self.use_context_time = use_context_time
         self.use_context_features = use_context_features
@@ -841,6 +847,10 @@ def __init__(
         self.has_context = has_context
         self.channels_ids = [sum(has_context[:i]) for i in range(len(has_context))]
 
+        self.use_norm = norm_quantile > 0.0
+        self.norm_quantile = norm_quantile
+        self.use_magnitude_channels = use_magnitude_channels
+
         assert (
             len(factors) == num_layers
             and len(attentions) >= num_layers
@@ -943,7 +953,7 @@ def __init__(
 
         self.to_out = Unpatcher(
             in_channels=channels,
-            out_channels=out_channels,
+            out_channels=out_channels * (2 if use_magnitude_channels else 1),
             blocks=patch_blocks,
             factor=patch_factor,
             context_mapping_features=context_mapping_features,
@@ -1002,10 +1012,11 @@ def forward(
         # Concat context channels at layer 0 if provided
         channels = self.get_channels(channels_list, layer=0)
         x = torch.cat([x, channels], dim=1) if exists(channels) else x
+        # Compute mapping from time and features
         mapping = self.get_mapping(time, features)
-
-        if self.use_norm:
-            x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha)
+        # Compute norm scale
+        scale = get_norm_scale(x, self.norm_quantile) if self.use_norm else 1.0
+        x = x / scale
 
         x = self.to_in(x, mapping)
         skips_list = [x]
@@ -1026,10 +1037,10 @@ def forward(
         x += skips_list.pop()
         x = self.to_out(x, mapping)
 
-        if self.use_norm:
-            x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha)
+        if self.use_magnitude_channels:
+            x = merge_magnitude_channels(x)
 
-        return x
+        return x * scale
 
 
 class FixedEmbedding(nn.Module):
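Taken together, the two `UNet1d.forward` hunks replace the old `wave_norm`/`wave_unnorm` pair with a plain scale-and-unscale around the network. Roughly (a sketch only, assuming the helpers defined earlier in this diff are in scope, with `net` as a hypothetical stand-in for the to_in → downsample/upsample → to_out path):

```python
def normalized_forward(net, x, norm_quantile=0.95, use_magnitude_channels=True):
    # Normalize the input by its quantile-based scale (1.0 disables it).
    scale = get_norm_scale(x, norm_quantile) if norm_quantile > 0.0 else 1.0
    x = net(x / scale)
    # Fold the doubled output channels back down, then undo the scaling.
    if use_magnitude_channels:
        x = merge_magnitude_channels(x)
    return x * scale
```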
@@ -1130,16 +1141,13 @@ def __init__(
         num_blocks: Sequence[int],
         use_noisy: bool = False,
         bottleneck: Optional[Bottleneck] = None,
-        norm: float = 0.0,
-        norm_alpha: float = 20.0,
+        use_magnitude_channels: bool = False,
     ):
         super().__init__()
         num_layers = len(multipliers) - 1
         self.bottleneck = bottleneck
         self.use_noisy = use_noisy
-        self.use_norm = norm > 0.0
-        self.norm = norm
-        self.norm_alpha = norm_alpha
+        self.use_magnitude_channels = use_magnitude_channels
 
         assert len(factors) >= num_layers and len(num_blocks) >= num_layers
 
@@ -1181,16 +1189,14 @@ def __init__(
 
         self.to_out = Unpatcher(
             in_channels=channels * (use_noisy + 1),
-            out_channels=in_channels,
+            out_channels=in_channels * (2 if use_magnitude_channels else 1),
             blocks=patch_blocks,
             factor=patch_factor,
         )
 
     def encode(
         self, x: Tensor, with_info: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Any]]:
-        if self.use_norm:
-            x = wave_norm(x, peak=self.norm, alpha=self.norm_alpha)
 
         x = self.to_in(x)
         for downsample in self.downsamples:
@@ -1206,12 +1212,14 @@ def decode(self, x: Tensor) -> Tensor:
             if self.use_noisy:
                 x = torch.cat([x, torch.randn_like(x)], dim=1)
             x = upsample(x)
+
         if self.use_noisy:
             x = torch.cat([x, torch.randn_like(x)], dim=1)
+
         x = self.to_out(x)
 
-        if self.use_norm:
-            x = wave_unnorm(x, peak=self.norm, alpha=self.norm_alpha)
+        if self.use_magnitude_channels:
+            x = merge_magnitude_channels(x)
 
         return x
 
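The autoencoder follows the same pattern: with `use_magnitude_channels` enabled, the final `Unpatcher` emits twice the target channels and `decode` folds them back into the original channel count. A quick shape check (illustrative values only, not from the diff):

```python
import torch

in_channels = 2  # e.g. stereo audio
use_magnitude_channels = True

# Channels produced by the decoder's Unpatcher.
unpatcher_out = in_channels * (2 if use_magnitude_channels else 1)  # 4

decoded = torch.randn(1, unpatcher_out, 65536)
waveform, magnitude = torch.chunk(decoded, chunks=2, dim=1)
recon = torch.sigmoid(waveform) * torch.tanh(magnitude)
assert recon.shape == (1, in_channels, 65536)
```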