11from typing import Optional , Sequence
22
3+ import torch
34from torch import Tensor , nn
45
56from .diffusion import (
@@ -111,7 +112,7 @@ def __init__(self, *args, **kwargs):
111112 use_skip_scale = True ,
112113 diffusion_sigma_distribution = LogNormalDistribution (mean = - 3.0 , std = 1.0 ),
113114 diffusion_sigma_data = 0.1 ,
114- diffusion_dynamic_threshold = 0.95 ,
115+ diffusion_dynamic_threshold = 0.0 ,
115116 )
116117
117118 super ().__init__ (* args , ** {** default_kwargs , ** kwargs })
@@ -122,3 +123,50 @@ def sample(self, *args, **kwargs):
122123 sampler = ADPM2Sampler (rho = 1.0 ),
123124 )
124125 return super ().sample (* args , ** {** default_kwargs , ** kwargs })
126+
127+
class AudioDiffusionUpsampler(Model1d):
    """Diffusion model conditioned on a low-resolution copy of the signal.

    In ``forward`` the input is decimated by ``factor`` along dim 2, stretched
    back by repeating every kept item ``factor`` times, and handed to
    ``self.diffusion`` as the single context tensor. In ``sample`` a
    low-resolution ``start`` tensor is stretched the same way and used as
    context while sampling starts from Gaussian noise of the stretched shape.

    NOTE(review): tensors are indexed as 3-D; presumably the layout is
    (batch, channels, time) -- confirm against ``Model1d``.
    """

    def __init__(self, factor: int, in_channels: int = 1, *args, **kwargs):
        """Build the upsampler with its default network configuration.

        Args:
            factor: Decimation / repetition factor applied along dim 2.
            in_channels: Forwarded as ``in_channels`` and also used as the
                only entry of ``context_channels``.
            *args: Forwarded positionally to the parent constructor.
            **kwargs: Forwarded to the parent constructor; any key given here
                overrides the matching default below.
        """
        # Plain int attribute, assigned before the parent constructor runs
        # (kept where the original had it).
        self.factor = factor

        default_kwargs = dict(
            in_channels=in_channels,
            channels=128,
            patch_size=16,
            kernel_sizes_init=[1, 3, 7],
            multipliers=[1, 2, 4, 4, 4, 4, 4],
            factors=[4, 4, 4, 2, 2, 2],
            num_blocks=[2, 2, 2, 2, 2, 2],
            attentions=[False, False, False, True, True, True],
            attention_heads=8,
            attention_features=64,
            attention_multiplier=2,
            use_attention_bottleneck=True,
            resnet_groups=8,
            kernel_multiplier_downsample=2,
            use_nearest_upsample=False,
            use_skip_scale=True,
            diffusion_sigma_distribution=LogNormalDistribution(mean=-3.0, std=1.0),
            diffusion_sigma_data=0.1,
            diffusion_dynamic_threshold=0.0,
            context_channels=[in_channels],
        )

        # The merged dict must be unpacked with `**`; passing it bare would
        # send it as one extra positional argument and drop every default.
        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Run ``self.diffusion`` on ``x`` with a degraded copy of ``x`` as context.

        Args:
            x: Input tensor, indexed as 3-D; dim 2 is the axis that is decimated.
            **kwargs: Forwarded unchanged to ``self.diffusion``.

        Returns:
            Whatever ``self.diffusion`` returns.
        """
        # Downsample by picking every `factor` item
        downsampled = x[:, :, :: self.factor]
        # Upsample by interleaving to get context
        context = torch.repeat_interleave(downsampled, repeats=self.factor, dim=2)
        # When the length of dim 2 is not a multiple of `factor`, the repeated
        # tensor overshoots; crop so the context always matches `x` in length.
        context = context[:, :, : x.shape[2]]
        return self.diffusion(x, context=[context], **kwargs)

    def sample(self, start: Tensor, *args, **kwargs):  # type: ignore
        """Sample a signal ``factor`` times longer than ``start`` along dim 2.

        Args:
            start: Low-resolution tensor; each item along dim 2 is repeated
                ``factor`` times to form the context.
            *args: Forwarded positionally to the parent ``sample`` after the
                initial noise tensor.
            **kwargs: Forwarded to the parent ``sample``; keys given here
                override the default ``context``, ``sigma_schedule`` and
                ``sampler``.

        Returns:
            Whatever the parent ``sample`` returns.
        """
        context = torch.repeat_interleave(start, repeats=self.factor, dim=2)
        # Starting noise has the same shape, dtype and device as the context.
        noise = torch.randn_like(context)
        default_kwargs = dict(
            context=[context],
            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            sampler=ADPM2Sampler(rho=1.0),
        )
        return super().sample(noise, *args, **{**default_kwargs, **kwargs})  # type: ignore # noqa
0 commit comments