@@ -26,38 +26,62 @@ from audio_diffusion_pytorch import AudioDiffusionModel
2626model = AudioDiffusionModel(in_channels = 1 )
2727
2828# Train model with audio sources
29- x = torch.randn(2 , 1 , 2 ** 18 ) # [batch, in_channels, samples], 2**18 ≈ 12s of audio at a frequency of 22050Hz
29+ x = torch.randn(2 , 1 , 2 ** 18 ) # [batch, in_channels, samples], 2**18 ≈ 12s of audio at a sample rate of 22050Hz
3030loss = model(x)
3131loss.backward() # Do this many times
3232
3333# Sample 2 sources given start noise
3434noise = torch.randn(2 , 1 , 2 ** 18 )
3535sampled = model.sample(
3636 noise = noise,
37- num_steps = 5 # Suggested range: 2-100
38- ) # [2, 1, 262144 ]
37+ num_steps = 5 # Suggested range: 2-50
38+ ) # [2, 1, 2 ** 18 ]
3939```
4040
4141### Upsampling
4242``` py
4343from audio_diffusion_pytorch import AudioDiffusionUpsampler
4444
4545upsampler = AudioDiffusionUpsampler(
46- factor = 4 ,
47- in_channels = 1
46+ in_channels = 1 ,
47+ factor = 8 ,
4848)
4949
5050# Train on high frequency data
51- x = torch.randn(2 , 1 , 2 ** 18 ) # [batch, in_channels, samples]
51+ x = torch.randn(2 , 1 , 2 ** 18 )
5252loss = upsampler(x)
5353loss.backward()
5454
5555# Given an undersampled source, sample the upsampled source
56- start = torch.randn(1 , 1 , 2 ** 16 )
57- sampled = upsampler.sample(
58- start = start,
59- num_steps = 5 # Suggested range: 2-100
56+ undersampled = torch.randn(1 , 1 , 2 ** 15 )
57+ upsampled = upsampler.sample(
58+ undersampled,
59+ num_steps = 5
60+ ) # [1, 1, 2 ** 18]
61+ ```
62+
63+ ### Autoencoding
64+ ``` py
65+ autoencoder = AudioDiffusionAutoencoder(
66+ in_channels = 1 ,
67+ encoder_depth = 4 ,
68+ encoder_channels = 32
6069)
70+
71+ # Train on audio samples
72+ x = torch.randn(2 , 1 , 2 ** 18 )
73+ loss = autoencoder(x)
74+ loss.backward()
75+
76+ # Encode audio source into latent
77+ x = torch.randn(2 , 1 , 2 ** 18 )
78+ latent = autoencoder.encode(x) # [2, 32, 128]
79+
80+ # Decode latent by diffusion sampling
81+ decoded = autoencoder.decode(
82+ latent,
83+ num_steps = 5
84+ ) # [2, 1, 2**18]
6185```
6286
6387## Usage with Components
0 commit comments