
Commit 7558d3f

small convenience for setting which layers get full attention
1 parent 9e43418 commit 7558d3f
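From the caller's side, the convenience is that `full_attn` can simply be left as `None` and the Unet derives the per-stage tuple from `len(dim_mults)`. A hedged usage sketch: `dim` and `dim_mults` are assumed from the library's usual constructor, while `full_attn` itself appears in this diff.

```python
from denoising_diffusion_pytorch import Unet

# default: full_attn = None expands to (False, False, False, True) for
# four stages, i.e. full attention only at the innermost layer
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)

# an explicit override still works: choose exactly which stages get full attention
model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    full_attn = (False, False, True, True)
)
```

Judging from the `cast_tuple(full_attn, num_stages)` call in the diff below, a single bool presumably broadcasts to every stage as well.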

File tree

2 files changed: 4 additions & 3 deletions


denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -286,7 +286,7 @@ def __init__(
         sinusoidal_pos_emb_theta = 10000,
         attn_dim_head = 32,
         attn_heads = 4,
-        full_attn = None,#default(F, F, F, T)
+        full_attn = None, # defaults to full attention only for inner most layer
         flash_attn = False
     ):
         super().__init__()
@@ -328,7 +328,8 @@ def __init__(
         # attention

         if not full_attn:
-            full_attn = tuple([False] * (len(dim_mults)-1) + [True])
+            full_attn = (*((False,) * (len(dim_mults) - 1)), True)
+
         num_stages = len(dim_mults)
         full_attn = cast_tuple(full_attn, num_stages)
         attn_heads = cast_tuple(attn_heads, num_stages)
```
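For clarity, a minimal sketch of what the new default expands to, assuming four stages; `cast_tuple` is assumed to mirror the repo's usual tuple-broadcasting helper.

```python
# sketch of the new default logic; cast_tuple is an assumption modeled on
# the helper in denoising_diffusion_pytorch.py (broadcast a scalar to a tuple)

def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

dim_mults = (1, 2, 4, 8)
full_attn = None

if not full_attn:
    # linear attention at every stage except the innermost one
    full_attn = (*((False,) * (len(dim_mults) - 1)), True)

num_stages = len(dim_mults)
full_attn = cast_tuple(full_attn, num_stages)

assert full_attn == (False, False, False, True)
```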
Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-__version__ = '1.9.0'
+__version__ = '1.9.1'
```
