66from einops import rearrange
77from torch import Tensor , nn
88
9- from .diffusion import (
10- DiffusionSampler ,
11- KDiffusion ,
12- LinearSchedule ,
13- Sampler ,
14- Schedule ,
15- UniformDistribution ,
16- VDiffusion ,
17- VKDiffusion ,
18- VSampler ,
19- )
20- from .modules import STFT , SinusoidalEmbedding , UNet1d , UNetConditional1d
9+ from .diffusion import LinearSchedule , UniformDistribution , VSampler , XDiffusion
10+ from .modules import STFT , Conv1d , SinusoidalEmbedding , UNet1d , UNetConditional1d
2111from .utils import (
2212 default ,
2313 downsample ,
@@ -44,36 +34,15 @@ def __init__(
4434 UNet = UNetConditional1d if use_classifier_free_guidance else UNet1d
4535 self .unet = UNet (** kwargs )
4636
47- # Check valid diffusion type
48- diffusion_classes = [VDiffusion , KDiffusion , VKDiffusion ]
49- aliases = [t .alias for t in diffusion_classes ] # type: ignore
50- message = f"diffusion_type='{ diffusion_type } ' must be one of { * aliases ,} "
51- assert diffusion_type in aliases , message
52-
53- for XDiffusion in diffusion_classes :
54- if XDiffusion .alias == diffusion_type : # type: ignore
55- self .diffusion = XDiffusion (net = self .unet , ** diffusion_kwargs )
37+ self .diffusion = XDiffusion (
38+ type = diffusion_type , net = self .unet , ** diffusion_kwargs
39+ )
5640
5741 def forward (self , x : Tensor , ** kwargs ) -> Tensor :
5842 return self .diffusion (x , ** kwargs )
5943
60- def sample (
61- self ,
62- noise : Tensor ,
63- num_steps : int ,
64- sigma_schedule : Schedule ,
65- sampler : Sampler ,
66- clamp : bool ,
67- ** kwargs ,
68- ) -> Tensor :
69- diffusion_sampler = DiffusionSampler (
70- diffusion = self .diffusion ,
71- sampler = sampler ,
72- sigma_schedule = sigma_schedule ,
73- num_steps = num_steps ,
74- clamp = clamp ,
75- )
76- return diffusion_sampler (noise , ** kwargs )
44+ def sample (self , * args , ** kwargs ) -> Tensor :
45+ return self .diffusion .sample (* args , ** kwargs )
7746
7847
7948class DiffusionUpsampler1d (Model1d ):
@@ -139,69 +108,70 @@ def sample( # type: ignore
139108 return super ().sample (noise , ** {** default_kwargs , ** kwargs }) # type: ignore
140109
141110
142- class DiffusionAutoencoder1d (Model1d ):
111+ class DiffusionAutoencoder1d (nn . Module ):
143112 def __init__ (
144113 self ,
145114 in_channels : int ,
146- channels : int ,
147- patch_blocks : int ,
148- patch_factor : int ,
149- multipliers : Sequence [int ],
150- factors : Sequence [int ],
151- num_blocks : Sequence [int ],
152- resnet_groups : int ,
153- kernel_multiplier_downsample : int ,
154- encoder_depth : int ,
155- encoder_num_blocks : Optional [Sequence [int ]] = None ,
115+ encoder_inject_depth : int ,
116+ encoder_channels : int ,
117+ encoder_factors : Sequence [int ],
118+ encoder_multipliers : Sequence [int ],
119+ diffusion_type : str ,
120+ encoder_patch_size : int = 1 ,
156121 bottleneck : Union [Bottleneck , Sequence [Bottleneck ]] = [],
157122 bottleneck_channels : Optional [int ] = None ,
158- use_stft : bool = False ,
159123 ** kwargs ,
160124 ):
125+ super ().__init__ ()
161126 self .in_channels = in_channels
162- encoder_num_blocks = default (encoder_num_blocks , num_blocks )
163- assert_message = "The number of encoder_num_blocks must match encoder_depth"
164- assert len (encoder_num_blocks ) >= encoder_depth , assert_message
165- assert patch_blocks == 1 , "patch_blocks != 1 not supported"
166- assert not use_stft , "use_stft not supported"
167- self .factor = patch_factor * prod (factors [0 :encoder_depth ])
168-
169- context_channels = [0 ] * encoder_depth
127+
128+ encoder_kwargs , kwargs = groupby ("encoder_" , kwargs )
129+ diffusion_kwargs , kwargs = groupby ("diffusion_" , kwargs )
130+
131+ # Compute context channels
132+ context_channels = [0 ] * encoder_inject_depth
170133 if exists (bottleneck_channels ):
171134 context_channels += [bottleneck_channels ]
172135 else :
173- context_channels += [channels * multipliers [ encoder_depth ]]
136+ context_channels += [encoder_channels * encoder_multipliers [ - 1 ]]
174137
175- super ().__init__ (
176- in_channels = in_channels ,
177- channels = channels ,
178- patch_blocks = patch_blocks ,
179- patch_factor = patch_factor ,
180- multipliers = multipliers ,
181- factors = factors ,
182- num_blocks = num_blocks ,
183- resnet_groups = resnet_groups ,
184- kernel_multiplier_downsample = kernel_multiplier_downsample ,
185- context_channels = context_channels ,
186- ** kwargs ,
138+ self .unet = UNet1d (
139+ in_channels = in_channels , context_channels = context_channels , ** kwargs
140+ )
141+
142+ self .diffusion = XDiffusion (
143+ type = diffusion_type , net = self .unet , ** diffusion_kwargs
187144 )
188145
189- self .bottlenecks = nn .ModuleList (to_list (bottleneck ))
190146 self .encoder = Encoder1d (
191147 in_channels = in_channels ,
192- channels = channels ,
193- patch_size = patch_factor ,
194- multipliers = multipliers [0 : encoder_depth + 1 ],
195- factors = factors [0 :encoder_depth ],
196- num_blocks = encoder_num_blocks [0 :encoder_depth ],
197- resnet_groups = resnet_groups ,
148+ channels = encoder_channels ,
149+ patch_size = encoder_patch_size ,
150+ factors = encoder_factors ,
151+ multipliers = encoder_multipliers ,
198152 out_channels = bottleneck_channels ,
153+ ** encoder_kwargs ,
199154 )
200155
156+ if exists (bottleneck_channels ):
157+ self .to_bottleneck = Conv1d (
158+ in_channels = encoder_channels * encoder_multipliers [- 1 ],
159+ out_channels = bottleneck_channels ,
160+ kernel_size = 1 ,
161+ )
162+
163+ self .encoder_downsample_factor = encoder_patch_size * prod (encoder_factors )
164+ self .bottleneck_channels = bottleneck_channels
165+ self .bottlenecks = nn .ModuleList (to_list (bottleneck ))
166+
201167 def encode (
202168 self , x : Tensor , with_info : bool = False
203169 ) -> Union [Tensor , Tuple [Tensor , Any ]]:
204170 latent , info = self .encoder (x , with_info = True )
171+ # Convert latent channels
172+ if exists (self .bottleneck_channels ):
173+ latent = self .to_bottleneck (latent )
174+ # Apply bottlenecks if present
205175 for bottleneck in self .bottlenecks :
206176 latent , info_bottleneck = bottleneck (latent , with_info = True )
207177 info = {** info , ** prefix_dict ("bottleneck_" , info_bottleneck )}
@@ -215,13 +185,17 @@ def forward( # type: ignore
215185 return (loss , info ) if with_info else loss
216186
217187 def decode (self , latent : Tensor , ** kwargs ) -> Tensor :
218- b , length = latent .shape [0 ], latent .shape [2 ] * self .factor
188+ b = latent .shape [0 ]
189+ length = latent .shape [2 ] * self .encoder_downsample_factor
219190 # Compute noise by inferring shape from latent length
220- noise = torch .randn (b , self .in_channels , length ). to ( latent )
191+ noise = torch .randn (b , self .in_channels , length , device = latent . device )
221192 # Compute context from latent
222193 default_kwargs = dict (channels_list = [latent ])
223194 # Decode by sampling while conditioning on latent channels
224- return super ().sample (noise , ** {** default_kwargs , ** kwargs }) # type: ignore
195+ return self .sample (noise , ** {** default_kwargs , ** kwargs })
196+
197+ def sample (self , * args , ** kwargs ) -> Tensor :
198+ return self .diffusion .sample (* args , ** kwargs )
225199
226200
227201class DiffusionVocoder1d (Model1d ):
@@ -339,7 +313,14 @@ def sample(self, *args, **kwargs):
339313class AudioDiffusionAutoencoder (DiffusionAutoencoder1d ):
340314 def __init__ (self , * args , ** kwargs ):
341315 default_kwargs = dict (
342- ** get_default_model_kwargs (), encoder_depth = 4 , encoder_channels = 64
316+ ** get_default_model_kwargs (),
317+ encoder_inject_depth = 6 ,
318+ encoder_channels = 16 ,
319+ encoder_patch_size = 16 ,
320+ encoder_multipliers = [1 , 2 , 4 , 4 , 4 , 4 , 4 ],
321+ encoder_factors = [4 , 4 , 4 , 2 , 2 , 2 ],
322+ encoder_num_blocks = [2 , 2 , 2 , 2 , 2 , 2 ],
323+ bottleneck_channels = 64 ,
343324 )
344325 super ().__init__ (* args , ** {** default_kwargs , ** kwargs }) # type: ignore
345326
@@ -398,7 +379,6 @@ def __init__(self, in_channels: int, **kwargs):
398379 use_nearest_upsample = False ,
399380 use_skip_scale = True ,
400381 use_context_time = True ,
401- use_magnitude_channels = False ,
402382 diffusion_type = "v" ,
403383 diffusion_sigma_distribution = UniformDistribution (),
404384 )
0 commit comments