Skip to content

Commit adc9b5e

Browse files
feat: add diffusion autoencoder example
1 parent 8b2ccc7 commit adc9b5e

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

README.md

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,38 +26,62 @@ from audio_diffusion_pytorch import AudioDiffusionModel
2626
model = AudioDiffusionModel(in_channels=1)
2727

2828
# Train model with audio sources
29-
x = torch.randn(2, 1, 2 ** 18) # [batch, in_channels, samples], 2**18 ≈ 12s of audio at a frequency of 22050Hz
29+
x = torch.randn(2, 1, 2 ** 18) # [batch, in_channels, samples], 2**18 ≈ 12s of audio at a frequency of 22050
3030
loss = model(x)
3131
loss.backward() # Do this many times
3232

3333
# Sample 2 sources given start noise
3434
noise = torch.randn(2, 1, 2 ** 18)
3535
sampled = model.sample(
3636
noise=noise,
37-
num_steps=5 # Suggested range: 2-100
38-
) # [2, 1, 262144]
37+
num_steps=5 # Suggested range: 2-50
38+
) # [2, 1, 2 ** 18]
3939
```
4040

4141
### Upsampling
4242
```py
4343
from audio_diffusion_pytorch import AudioDiffusionUpsampler
4444

4545
upsampler = AudioDiffusionUpsampler(
46-
factor=4,
47-
in_channels=1
46+
in_channels=1,
47+
factor=8,
4848
)
4949

5050
# Train on high frequency data
51-
x = torch.randn(2, 1, 2 ** 18) # [batch, in_channels, samples]
51+
x = torch.randn(2, 1, 2 ** 18)
5252
loss = upsampler(x)
5353
loss.backward()
5454

5555
# Given start undersampled source, samples upsampled source
56-
start = torch.randn(1, 1, 2 ** 16)
57-
sampled = upsampler.sample(
58-
start=start,
59-
num_steps=5 # Suggested range: 2-100
56+
undersampled = torch.randn(1, 1, 2 ** 15)
57+
upsampled = upsampler.sample(
58+
undersampled,
59+
num_steps=5
60+
) # [1, 1, 2 ** 18]
61+
```
62+
63+
### Autoencoding
64+
```py
65+
autoencoder = AudioDiffusionAutoencoder(
66+
in_channels=1,
67+
encoder_depth=4,
68+
encoder_channels=32
6069
)
70+
71+
# Train on audio samples
72+
x = torch.randn(2, 1, 2 ** 18)
73+
loss = autoencoder(x)
74+
loss.backward()
75+
76+
# Encode audio source into latent
77+
x = torch.randn(2, 1, 2 ** 18)
78+
latent = autoencoder.encode(x) # [2, 32, 128]
79+
80+
# Decode latent by diffusion sampling
81+
decoded = autoencoder.decode(
82+
latent,
83+
num_steps=5
84+
) # [2, 32, 2**18]
6185
```
6286

6387
## Usage with Components

0 commit comments

Comments
 (0)