Skip to content

Commit dcc3da8

Browse files
committed
able to customize attention heads and dimension per head
1 parent 596aa34 commit dcc3da8

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ def __init__(
263263
learned_sinusoidal_cond = False,
264264
random_fourier_features = False,
265265
learned_sinusoidal_dim = 16,
266+
attn_dim_head = 32,
267+
attn_heads = 4
266268
):
267269
super().__init__()
268270

@@ -334,7 +336,7 @@ def __init__(
334336

335337
mid_dim = dims[-1]
336338
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
337-
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
339+
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, dim_head = attn_dim_head, heads = attn_heads)))
338340
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim, classes_emb_dim = classes_dim)
339341

340342
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ def __init__(
271271
learned_sinusoidal_cond = False,
272272
random_fourier_features = False,
273273
learned_sinusoidal_dim = 16,
274+
attn_dim_head = 32,
275+
attn_heads = 4,
274276
full_attn = (False, False, False, True),
275277
flash_attn = False
276278
):
@@ -331,7 +333,7 @@ def __init__(
331333
self.downs.append(nn.ModuleList([
332334
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
333335
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
334-
attn_klass(dim_in),
336+
attn_klass(dim_in, dim_head = attn_dim_head, heads = attn_heads),
335337
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
336338
]))
337339

@@ -348,7 +350,7 @@ def __init__(
348350
self.ups.append(nn.ModuleList([
349351
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
350352
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
351-
attn_klass(dim_out),
353+
attn_klass(dim_out, dim_head = attn_dim_head, heads = attn_heads),
352354
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
353355
]))
354356

denoising_diffusion_pytorch/denoising_diffusion_pytorch_1d.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,9 @@ def __init__(
265265
learned_variance = False,
266266
learned_sinusoidal_cond = False,
267267
random_fourier_features = False,
268-
learned_sinusoidal_dim = 16
268+
learned_sinusoidal_dim = 16,
269+
attn_dim_head = 32,
270+
attn_heads = 4
269271
):
270272
super().__init__()
271273

@@ -321,7 +323,7 @@ def __init__(
321323

322324
mid_dim = dims[-1]
323325
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
324-
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
326+
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, dim_head = attn_dim_head, heads = attn_heads)))
325327
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
326328

327329
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.8.5'
1+
__version__ = '1.8.6'

0 commit comments

Comments
 (0)