Commit 68da808

feat: add v-diffusion class/distribution, refactor diffusion
1 parent 125b938 commit 68da808

5 files changed, +241 -121 lines changed

README.md

Lines changed: 12 additions & 3 deletions
@@ -139,6 +139,7 @@ unet = UNet1d(
     use_nearest_upsample=False,
     use_skip_scale=True,
     use_context_time=True,
+    use_magnitude_channels=False
 )
 
 x = torch.randn(3, 1, 2 ** 16)
@@ -151,13 +152,20 @@ y = unet(x, t) # [3, 1, 32768], compute 3 samples of ~1.5 seconds at 22050Hz wit
 
 #### Training
 ```python
-from audio_diffusion_pytorch import Diffusion, LogNormalDistribution
+from audio_diffusion_pytorch import KDiffusion, VDiffusion, LogNormalDistribution, VDistribution
 
-diffusion = Diffusion(
+# Either use KDiffusion
+diffusion = KDiffusion(
     net=unet,
     sigma_distribution=LogNormalDistribution(mean = -3.0, std = 1.0),
     sigma_data=0.1,
-    dynamic_threshold=0.95
+    dynamic_threshold=0.0
+)
+
+# Or use VDiffusion
+diffusion = VDiffusion(
+    net=unet,
+    sigma_distribution=VDistribution()
 )
 
 x = torch.randn(3, 1, 2 ** 18) # Batch of training audio samples
@@ -239,6 +247,7 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
 - [x] Add conditional model with classifier-free guidance.
 - [x] Add option to provide context features mapping.
 - [x] Add option to change number of (cross) attention blocks.
+- [x] Add `VDiffusion` option.
 - [ ] Add flash attention.
 
 
audio_diffusion_pytorch/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -7,10 +7,13 @@
     Distribution,
     KarrasSampler,
     KarrasSchedule,
+    KDiffusion,
     LogNormalDistribution,
     Sampler,
     Schedule,
     SpanBySpanComposer,
+    VDiffusion,
+    VDistribution,
 )
 from .model import (
     AudioDiffusionAutoencoder,
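
The hunks above only show how `VDiffusion` and `VDistribution` are constructed, not what they compute. As a rough illustration, here is the commonly used v-objective in plain Python, so it runs without PyTorch. This is a sketch under assumptions, not the code from this commit: it assumes noise levels lie in [0, 1] and are mapped to signal and noise weights with a cosine/sine schedule, and the helper names `alpha_beta`, `v_training_pair`, and `recover_x` are hypothetical.

```python
import math
import random


def alpha_beta(sigma: float) -> tuple[float, float]:
    """Map a noise level in [0, 1] to (signal weight, noise weight).

    Assumed schedule: a point on the unit circle, so alpha**2 + beta**2 == 1.
    """
    angle = sigma * math.pi / 2
    return math.cos(angle), math.sin(angle)


def v_training_pair(x, noise, sigma):
    """Return (noisy input, v target) for one sample given as a list of floats."""
    alpha, beta = alpha_beta(sigma)
    x_noisy = [alpha * xi + beta * ni for xi, ni in zip(x, noise)]
    v_target = [alpha * ni - beta * xi for xi, ni in zip(x, noise)]
    return x_noisy, v_target


def recover_x(x_noisy, v, sigma):
    """Invert the mix: alpha * x_noisy - beta * v gives back the clean sample."""
    alpha, beta = alpha_beta(sigma)
    return [alpha * xn - beta * vi for xn, vi in zip(x_noisy, v)]


# Hypothetical toy data standing in for a batch of audio samples.
rng = random.Random(0)
x = [rng.uniform(-1.0, 1.0) for _ in range(8)]
noise = [rng.gauss(0.0, 1.0) for _ in range(8)]
sigma = rng.uniform(0.0, 1.0)  # assumed: noise level drawn uniformly from [0, 1]

x_noisy, v = v_training_pair(x, noise, sigma)
x_rec = recover_x(x_noisy, v, sigma)
max_err = max(abs(a - b) for a, b in zip(x, x_rec))
```

In a training loop a network would be asked to predict `v` from `x_noisy` and `sigma`, typically with a mean-squared-error loss. Because the weights lie on the unit circle, a perfect `v` prediction recovers the clean signal exactly, which is what `max_err` checks here.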
