99from tqdm import tqdm
1010
1111from .diffusion import LinearSchedule , UniformDistribution , VSampler , XDiffusion
12- from .modules import STFT , SinusoidalEmbedding , UNet1d , UNetCFG1d , rand_bool
12+ from .modules import STFT , SinusoidalEmbedding , XUNet1d , rand_bool
1313from .utils import (
1414 closest_power_2 ,
1515 default ,
2828
2929
3030class Model1d (nn .Module ):
31- def __init__ (
32- self , diffusion_type : str , use_classifier_free_guidance : bool = False , ** kwargs
33- ):
31+ def __init__ (self , unet_type : str = "base" , ** kwargs ):
3432 super ().__init__ ()
3533 diffusion_kwargs , kwargs = groupby ("diffusion_" , kwargs )
36-
37- UNet = UNetCFG1d if use_classifier_free_guidance else UNet1d
38- self .unet = UNet (** kwargs )
39-
40- self .diffusion = XDiffusion (
41- type = diffusion_type , net = self .unet , ** diffusion_kwargs
42- )
34+ self .unet = XUNet1d (type = unet_type , ** kwargs )
35+ self .diffusion = XDiffusion (net = self .unet , ** diffusion_kwargs )
4336
4437 def forward (self , x : Tensor , ** kwargs ) -> Tensor :
4538 return self .diffusion (x , ** kwargs )
@@ -119,10 +112,10 @@ def __init__(
119112 encoder_channels : int ,
120113 encoder_factors : Sequence [int ],
121114 encoder_multipliers : Sequence [int ],
122- diffusion_type : str ,
123115 encoder_patch_size : int = 1 ,
124116 bottleneck : Union [Bottleneck , Sequence [Bottleneck ]] = [],
125117 bottleneck_channels : Optional [int ] = None ,
118+ unet_type : str = "base" ,
126119 ** kwargs ,
127120 ):
128121 super ().__init__ ()
@@ -138,13 +131,14 @@ def __init__(
138131 else :
139132 context_channels += [encoder_channels * encoder_multipliers [- 1 ]]
140133
141- self .unet = UNet1d (
142- in_channels = in_channels , context_channels = context_channels , ** kwargs
134+ self .unet = XUNet1d (
135+ type = unet_type ,
136+ in_channels = in_channels ,
137+ context_channels = context_channels ,
138+ ** kwargs ,
143139 )
144140
145- self .diffusion = XDiffusion (
146- type = diffusion_type , net = self .unet , ** diffusion_kwargs
147- )
141+ self .diffusion = XDiffusion (net = self .unet , ** diffusion_kwargs )
148142
149143 self .encoder = Encoder1d (
150144 in_channels = in_channels ,
@@ -207,6 +201,7 @@ def __init__(
207201 encoder_patch_size : int = 1 ,
208202 bottleneck : Union [Bottleneck , Sequence [Bottleneck ]] = [],
209203 bottleneck_channels : Optional [int ] = None ,
204+ unet_type : str = "base" ,
210205 ** kwargs ,
211206 ):
212207 super ().__init__ ()
@@ -233,7 +228,8 @@ def __init__(
233228 use_complex = False , # Magnitude encoding
234229 )
235230
236- self .unet = UNet1d (
231+ self .unet = XUNet1d (
232+ type = unet_type ,
237233 in_channels = in_channels ,
238234 context_channels = context_channels ,
239235 use_stft = True ,
@@ -546,9 +542,9 @@ def __init__(
546542 self .embedding_mask_proba = embedding_mask_proba
547543 default_kwargs = dict (
548544 ** get_default_model_kwargs (),
545+ unet_type = "cfg" ,
549546 context_embedding_features = embedding_features ,
550547 context_embedding_max_length = embedding_max_length ,
551- use_classifier_free_guidance = True ,
552548 )
553549 super ().__init__ (** {** default_kwargs , ** kwargs })
554550
0 commit comments