Skip to content

Commit 51e3840

Browse files
committed
fix an issue with karras unet when no class labels, and make sure it is compatible with GaussianDiffusion wrapper
1 parent 3b78964 commit 51e3840

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def __init__(
489489
):
490490
super().__init__()
491491
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
492-
assert not model.random_or_learned_sinusoidal_cond
492+
assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond
493493

494494
self.model = model
495495

denoising_diffusion_pytorch/karras_unet.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
the magnitude-preserving unet proposed in https://arxiv.org/abs/2312.02696 by Karras et al.
3+
"""
4+
15
import math
26
from math import sqrt, ceil
37
from functools import partial
@@ -456,8 +460,9 @@ def __init__(
456460
self.needs_class_labels = exists(num_classes)
457461
self.num_classes = num_classes
458462

459-
self.to_class_emb = Linear(num_classes, 4 * dim)
460-
self.add_class_emb = MPAdd(t = mp_add_emb_t)
463+
if self.needs_class_labels:
464+
self.to_class_emb = Linear(num_classes, 4 * dim)
465+
self.add_class_emb = MPAdd(t = mp_add_emb_t)
461466

462467
# final embedding activations
463468

@@ -537,6 +542,8 @@ def __init__(
537542
Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
538543
])
539544

545+
self.out_dim = channels
546+
540547
@property
541548
def downsample_factor(self):
542549
return 2 ** self.num_downsamples
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.10.1'
1+
__version__ = '1.10.2'

0 commit comments

Comments
 (0)