|
@@ -9,6 +9,7 @@
 from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion
 from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
 from .utils import (
+    closest_power_2,
     default,
     downsample,
     exists,
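
The newly imported closest_power_2 is used by DiffusionMAE1d.decode below to round the generated audio length to a power of two. The helper itself is not part of this diff; a minimal sketch of its assumed behavior:

import math

def closest_power_2(x: float) -> int:
    # Round x to the nearest power of two by comparing the two candidates
    # on either side in log2 space (an assumption about the utility's
    # semantics, inferred from its use in decode below)
    exponent = math.log2(x)
    low, high = 2 ** math.floor(exponent), 2 ** math.ceil(exponent)
    return int(low if x - low <= high - x else high)
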
@@ -188,6 +189,100 @@ def sample(self, *args, **kwargs) -> Tensor: |
         return self.diffusion.sample(*args, **kwargs)
 
 
+class DiffusionMAE1d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        encoder_inject_depth: int,
+        encoder_channels: int,
+        encoder_factors: Sequence[int],
+        encoder_multipliers: Sequence[int],
+        diffusion_type: str,
+        stft_num_fft: int,
+        encoder_patch_size: int = 1,
+        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
+        bottleneck_channels: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+
+        encoder_kwargs, kwargs = groupby("encoder_", kwargs)
+        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
+
+        # Compute context channels
+        context_channels = [0] * encoder_inject_depth
+        if exists(bottleneck_channels):
+            context_channels += [bottleneck_channels]
+        else:
+            context_channels += [encoder_channels * encoder_multipliers[-1]]
+
+        self.spectrogram_channels = stft_num_fft // 2 + 1
+
+        self.unet = UNet1d(
+            in_channels=in_channels,
+            stft_num_fft=stft_num_fft,
+            context_channels=context_channels,
+            use_stft=True,
+            **kwargs,
+        )
+
+        self.stft = self.unet.stft
+
+        self.diffusion = XDiffusion(
+            type=diffusion_type, net=self.unet, **diffusion_kwargs
+        )
+
+        self.encoder = Encoder1d(
+            in_channels=in_channels * self.spectrogram_channels,
+            channels=encoder_channels,
+            patch_size=encoder_patch_size,
+            factors=encoder_factors,
+            multipliers=encoder_multipliers,
+            out_channels=bottleneck_channels,
+            **encoder_kwargs,
+        )
+
+        self.encoder_downsample_factor = encoder_patch_size * prod(encoder_factors)
+        self.bottleneck_channels = bottleneck_channels
+        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+
+    def encode(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        # Extract magnitude and encode
+        magnitude, _ = self.stft.encode(x)
+        magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
+        latent, info = self.encoder(magnitude_flat, with_info=True)
+        # Apply bottlenecks if present
+        for bottleneck in self.bottlenecks:
+            latent, info_bottleneck = bottleneck(latent, with_info=True)
+            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
+        return (latent, info) if with_info else latent
+
+    def forward(  # type: ignore
+        self, x: Tensor, with_info: bool = False, **kwargs
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        latent, info = self.encode(x, with_info=True)
+        loss = self.diffusion(x, channels_list=[latent], **kwargs)
+        return (loss, info) if with_info else loss
+    def decode(self, latent: Tensor, **kwargs) -> Tensor:
+        b = latent.shape[0]
+        length = closest_power_2(
+            self.stft.hop_length * latent.shape[2] * self.encoder_downsample_factor
+        )
+        # Compute noise, inferring its length from the latent shape
+        noise = torch.randn(b, self.in_channels, length, device=latent.device)
+        # Compute context from latent
+        default_kwargs = dict(channels_list=[latent])
+        # Decode by sampling while conditioning on latent channels
+        return self.sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
+
+    def sample(self, *args, **kwargs) -> Tensor:
+        return self.diffusion.sample(*args, **kwargs)
+
+
 class DiffusionVocoder1d(Model1d):
     def __init__(
         self,
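
Shape intuition for the class above: encode runs the STFT magnitude (flattened into channels) through Encoder1d, so the latent's time axis is the STFT frame count divided by encoder_downsample_factor; decode inverts this arithmetic to size its starting noise. A worked example with illustrative values (stft_hop_length=256, encoder_patch_size=1, encoder_factors=[2, 2, 2] are assumptions, not defaults from this commit):

hop_length = 256
encoder_downsample_factor = 1 * (2 * 2 * 2)  # patch_size * prod(factors)
latent_t = 128                               # time steps in the latent
length = hop_length * latent_t * encoder_downsample_factor
assert length == 2**18  # 262144 samples; closest_power_2 leaves this unchanged
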
@@ -318,6 +413,27 @@ def decode(self, *args, **kwargs): |
         return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs})
 
 
+class AudioDiffusionMAE(DiffusionMAE1d):
+    def __init__(self, *args, **kwargs):
+        default_kwargs = dict(
+            patch_blocks=1,
+            patch_factor=1,
+            resnet_groups=8,
+            kernel_multiplier_downsample=2,
+            use_nearest_upsample=False,
+            use_skip_scale=True,
+            use_context_time=True,
+            diffusion_type="v",
+            diffusion_sigma_distribution=UniformDistribution(),
+            stft_num_fft=1023,
+            stft_hop_length=256,
+        )
+        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
+
+    def decode(self, *args, **kwargs):
+        return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs})
+
+
 class AudioDiffusionConditional(Model1d):
     def __init__(
         self,
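
Taken together, the intended flow is: train with the diffusion loss, then encode audio to a compressed latent and decode by sampling while conditioning on it. A hedged usage sketch; the UNet and sampling keyword arguments below are illustrative guesses, not values confirmed by this commit:

import torch
from audio_diffusion_pytorch import AudioDiffusionMAE  # assumed export

# All UNet/encoder hyperparameters here are assumptions for illustration
model = AudioDiffusionMAE(
    in_channels=2,                      # stereo input
    channels=128,
    multipliers=[1, 2, 4, 4],
    factors=[2, 2, 2],
    num_blocks=[2, 2, 2],
    attentions=[0, 0, 0],
    encoder_inject_depth=2,
    encoder_channels=32,
    encoder_factors=[2, 2, 2],
    encoder_multipliers=[1, 2, 4, 4],
)

x = torch.randn(1, 2, 2**18)            # [batch, in_channels, samples]
loss = model(x)                         # diffusion training loss
loss.backward()

latent = model.encode(x)                # compressed latent of the magnitude STFT
y = model.decode(latent, num_steps=10)  # num_steps is assumed to be the
                                        # sampler's step-count argument
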
|