Skip to content

Commit d60eefa

Browse files
feat: update cross attention, option to change attention blocks, default to 16 patch factor
1 parent 9f57a87 commit d60eefa

File tree

4 files changed

+159
-236
lines changed

4 files changed

+159
-236
lines changed

README.md

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -125,17 +125,16 @@ from audio_diffusion_pytorch import UNet1d
125125
unet = UNet1d(
126126
in_channels=1,
127127
channels=128,
128-
patch_blocks=4,
129-
patch_factor=2,
128+
patch_blocks=16,
129+
patch_factor=1,
130130
kernel_sizes_init=[1, 3, 7],
131131
multipliers=[1, 2, 4, 4, 4, 4, 4],
132132
factors=[4, 4, 4, 2, 2, 2],
133-
attentions=[False, False, False, True, True, True],
133+
attentions=[0, 0, 0, 1, 1, 1, 1],
134134
num_blocks=[2, 2, 2, 2, 2, 2],
135135
attention_heads=8,
136136
attention_features=64,
137137
attention_multiplier=2,
138-
use_attention_bottleneck=True,
139138
resnet_groups=8,
140139
kernel_multiplier_downsample=2,
141140
use_nearest_upsample=False,
@@ -229,16 +228,20 @@ y_long = composer(y, keep_start=True) # [1, 1, 98304]
229228
- [x] Add elucidated diffusion.
230229
- [x] Add ancestral DPM2 sampler.
231230
- [x] Add dynamic thresholding.
232-
- [x] Add (variational) autoencoder option to compress audio before diffusion.
231+
- [x] Add (variational) autoencoder option to compress audio before diffusion (removed).
233232
- [x] Fix inpainting and make it work with ADPM2 sampler.
234233
- [x] Add trainer with experiments.
235234
- [x] Add diffusion upsampler.
236235
- [x] Add ancestral euler sampler `AEulerSampler`.
237236
- [x] Add diffusion autoencoder.
237+
- [x] Add diffusion upsampler.
238238
- [x] Add autoencoder bottleneck option for quantization.
239-
- [x] Add option to provide context tokens (resnet cross attention).
239+
- [x] Add option to provide context tokens (cross attention).
240240
- [x] Add conditional model with classifier-free guidance.
241241
- [x] Add option to provide context features mapping.
242+
- [x] Add option to change number of (cross) attention blocks.
243+
- [ ] Add flash attention.
244+
242245

243246
## Appreciation
244247

audio_diffusion_pytorch/model.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -202,17 +202,16 @@ def decode(self, latent: Tensor, **kwargs) -> Tensor:
202202
def get_default_model_kwargs():
203203
return dict(
204204
channels=128,
205-
patch_blocks=4,
206-
patch_factor=2,
205+
patch_blocks=1,
206+
patch_factor=16,
207207
kernel_sizes_init=[1, 3, 7],
208208
multipliers=[1, 2, 4, 4, 4, 4, 4],
209209
factors=[4, 4, 4, 2, 2, 2],
210210
num_blocks=[2, 2, 2, 2, 2, 2],
211-
attentions=[False, False, False, True, True, True],
211+
attentions=[0, 0, 0, 1, 1, 1, 1],
212212
attention_heads=8,
213213
attention_features=64,
214214
attention_multiplier=2,
215-
use_attention_bottleneck=True,
216215
resnet_groups=8,
217216
kernel_multiplier_downsample=2,
218217
use_nearest_upsample=False,

0 commit comments

Comments (0)