Skip to content

Commit 8ff6835

Browse files
committed
at least two attention heads in karras unet
1 parent 51e3840 commit 8ff6835

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

denoising_diffusion_pytorch/karras_unet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def __init__(
240240
if has_attn:
241241
self.attn = Attention(
242242
dim = dim_out,
243-
heads = ceil(dim_out / attn_dim_head),
243+
heads = max(ceil(dim_out / attn_dim_head), 2),
244244
dim_head = attn_dim_head,
245245
mp_add_t = attn_res_mp_add_t,
246246
flash = attn_flash
@@ -322,7 +322,7 @@ def __init__(
322322
if has_attn:
323323
self.attn = Attention(
324324
dim = dim_out,
325-
heads = ceil(dim_out / attn_dim_head),
325+
heads = max(ceil(dim_out / attn_dim_head), 2),
326326
dim_head = attn_dim_head,
327327
mp_add_t = attn_res_mp_add_t,
328328
flash = attn_flash
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.10.2'
1+
__version__ = '1.10.3'

0 commit comments

Comments
 (0)