
Commit dbe3b91

feat: add diffmae
1 parent b63df7d

File tree: 2 files changed (+117 −1 lines changed)


audio_diffusion_pytorch/model.py

Lines changed: 116 additions & 0 deletions
@@ -9,6 +9,7 @@
 from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion
 from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
 from .utils import (
+    closest_power_2,
     default,
     downsample,
     exists,
@@ -188,6 +189,100 @@ def sample(self, *args, **kwargs) -> Tensor:
         return self.diffusion.sample(*args, **kwargs)


+class DiffusionMAE1d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        encoder_inject_depth: int,
+        encoder_channels: int,
+        encoder_factors: Sequence[int],
+        encoder_multipliers: Sequence[int],
+        diffusion_type: str,
+        stft_num_fft: int,
+        encoder_patch_size: int = 1,
+        bottleneck: Union[Bottleneck, Sequence[Bottleneck]] = [],
+        bottleneck_channels: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+
+        encoder_kwargs, kwargs = groupby("encoder_", kwargs)
+        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
+
+        # Compute context channels
+        context_channels = [0] * encoder_inject_depth
+        if exists(bottleneck_channels):
+            context_channels += [bottleneck_channels]
+        else:
+            context_channels += [encoder_channels * encoder_multipliers[-1]]
+
+        self.spectrogram_channels = stft_num_fft // 2 + 1
+
+        self.unet = UNet1d(
+            in_channels=in_channels,
+            stft_num_fft=stft_num_fft,
+            context_channels=context_channels,
+            use_stft=True,
+            **kwargs,
+        )
+
+        self.stft = self.unet.stft
+
+        self.diffusion = XDiffusion(
+            type=diffusion_type, net=self.unet, **diffusion_kwargs
+        )
+
+        self.encoder = Encoder1d(
+            in_channels=in_channels * self.spectrogram_channels,
+            channels=encoder_channels,
+            patch_size=encoder_patch_size,
+            factors=encoder_factors,
+            multipliers=encoder_multipliers,
+            out_channels=bottleneck_channels,
+            **encoder_kwargs,
+        )
+
+        self.encoder_downsample_factor = encoder_patch_size * prod(encoder_factors)
+        self.bottleneck_channels = bottleneck_channels
+        self.bottlenecks = nn.ModuleList(to_list(bottleneck))
+
+    def encode(
+        self, x: Tensor, with_info: bool = False
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        # Extract magnitude and encode
+        magnitude, _ = self.stft.encode(x)
+        magnitude_flat = rearrange(magnitude, "b c f t -> b (c f) t")
+        latent, info = self.encoder(magnitude_flat, with_info=True)
+        # Apply bottlenecks if present
+        for bottleneck in self.bottlenecks:
+            latent, info_bottleneck = bottleneck(latent, with_info=True)
+            info = {**info, **prefix_dict("bottleneck_", info_bottleneck)}
+        return (latent, info) if with_info else latent
+
+    def forward(  # type: ignore
+        self, x: Tensor, with_info: bool = False, **kwargs
+    ) -> Union[Tensor, Tuple[Tensor, Any]]:
+        latent, info = self.encode(x, with_info=True)
+        loss = self.diffusion(x, channels_list=[latent], **kwargs)
+        return (loss, info) if with_info else loss
+
+    def decode(self, latent: Tensor, **kwargs) -> Tensor:
+        b = latent.shape[0]
+        length = closest_power_2(
+            self.stft.hop_length * latent.shape[2] * self.encoder_downsample_factor
+        )
+        # Compute noise by inferring shape from latent length
+        noise = torch.randn(b, self.in_channels, length, device=latent.device)
+        # Compute context from latent
+        default_kwargs = dict(channels_list=[latent])
+        # Decode by sampling while conditioning on latent channels
+        return self.sample(noise, **{**default_kwargs, **kwargs})  # type: ignore
+
+    def sample(self, *args, **kwargs) -> Tensor:
+        return self.diffusion.sample(*args, **kwargs)
+
+
 class DiffusionVocoder1d(Model1d):
     def __init__(
         self,
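Note on the decode path above: DiffusionMAE1d.decode infers how many audio samples to synthesize purely from the latent, as stft hop_length * latent length * encoder_downsample_factor, rounded to the closest power of two. A small worked example, assuming the stft_hop_length=256 default added further down in this commit and a hypothetical encoder with patch_size=1 and factors=[2, 2]:

# Worked example of the length computation in DiffusionMAE1d.decode.
# hop_length matches the AudioDiffusionMAE default below; the encoder settings
# (patch_size=1, factors=[2, 2]) and the latent length are assumptions.
from math import log2, prod

hop_length = 256
encoder_downsample_factor = 1 * prod([2, 2])   # patch_size * prod(factors) -> 4
latent_length = 256                            # latent.shape[2] in this sketch

raw_length = hop_length * latent_length * encoder_downsample_factor   # 262144
length = 2 ** round(log2(raw_length))          # closest power of two -> 2 ** 18 == 262144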
@@ -318,6 +413,27 @@ def decode(self, *args, **kwargs):
         return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs})


+class AudioDiffusionMAE(DiffusionMAE1d):
+    def __init__(self, *args, **kwargs):
+        default_kwargs = dict(
+            patch_blocks=1,
+            patch_factor=1,
+            resnet_groups=8,
+            kernel_multiplier_downsample=2,
+            use_nearest_upsample=False,
+            use_skip_scale=True,
+            use_context_time=True,
+            diffusion_type="v",
+            diffusion_sigma_distribution=UniformDistribution(),
+            stft_num_fft=1023,
+            stft_hop_length=256,
+        )
+        super().__init__(*args, **{**default_kwargs, **kwargs})  # type: ignore
+
+    def decode(self, *args, **kwargs):
+        return super().decode(*args, **{**get_default_sampling_kwargs(), **kwargs})
+
+
 class AudioDiffusionConditional(Model1d):
     def __init__(
         self,
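For context, a minimal usage sketch of the new AudioDiffusionMAE wrapper. Only the forward/encode/decode call pattern is taken from the diff above; every constructor value below (UNet width, multipliers, factors, attention settings, and the encoder configuration) is a hypothetical placeholder, since this commit does not pin those defaults and the full set of required UNet1d/Encoder1d kwargs is defined elsewhere in the library:

# Hypothetical usage sketch -- constructor values are assumptions, not part of this commit.
import torch
from audio_diffusion_pytorch import AudioDiffusionMAE

model = AudioDiffusionMAE(
    in_channels=2,                   # e.g. stereo audio (assumption)
    channels=128,                    # UNet1d base width (assumption)
    multipliers=[1, 2, 4, 4],        # per-stage channel multipliers (assumption)
    factors=[4, 4, 4],               # per-stage downsampling factors (assumption)
    num_blocks=[2, 2, 2],            # resnet blocks per stage (assumption)
    attentions=[0, 0, 1],            # attention only in the deepest stage (assumption)
    attention_heads=8,               # (assumption)
    attention_features=64,           # (assumption)
    attention_multiplier=2,          # (assumption)
    encoder_inject_depth=2,          # UNet stage that receives the latent as context (assumption)
    encoder_channels=16,             # (assumption)
    encoder_patch_size=1,
    encoder_factors=[2, 2],          # (assumption)
    encoder_multipliers=[1, 2, 4],   # (assumption)
    encoder_num_blocks=[2, 2],       # forwarded to Encoder1d via the encoder_ prefix (assumption)
)

x = torch.randn(1, 2, 2 ** 18)               # [batch, in_channels, samples]
loss = model(x)                              # diffusion loss conditioned on the encoded magnitude latent
loss.backward()

latent = model.encode(x)                     # compressed magnitude-spectrogram latent
audio = model.decode(latent, num_steps=50)   # sample audio from the latent (num_steps is an assumption)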

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
     name="audio-diffusion-pytorch",
     packages=find_packages(exclude=[]),
-    version="0.0.85",
+    version="0.0.86",
     license="MIT",
     description="Audio Diffusion - PyTorch",
     long_description_content_type="text/markdown",
