Skip to content

Commit f363854

Browse files
committed
fix an issue with karras unet when no class labels, and make sure it is compatible with GaussianDiffusion wrapper
1 parent 3b78964 commit f363854

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def __init__(
489489
):
490490
super().__init__()
491491
assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
492-
assert not model.random_or_learned_sinusoidal_cond
492+
assert not hasattr(model, 'random_or_learned_sinusoidal_cond') or not model.random_or_learned_sinusoidal_cond
493493

494494
self.model = model
495495

denoising_diffusion_pytorch/karras_unet.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,9 @@ def __init__(
456456
self.needs_class_labels = exists(num_classes)
457457
self.num_classes = num_classes
458458

459-
self.to_class_emb = Linear(num_classes, 4 * dim)
460-
self.add_class_emb = MPAdd(t = mp_add_emb_t)
459+
if self.needs_class_labels:
460+
self.to_class_emb = Linear(num_classes, 4 * dim)
461+
self.add_class_emb = MPAdd(t = mp_add_emb_t)
461462

462463
# final embedding activations
463464

@@ -537,6 +538,8 @@ def __init__(
537538
Decoder(curr_dim, curr_dim, has_attn = mid_has_attn, **block_kwargs),
538539
])
539540

541+
self.out_dim = channels
542+
540543
@property
541544
def downsample_factor(self):
542545
return 2 ** self.num_downsamples
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.10.1'
1+
__version__ = '1.10.2'

0 commit comments

Comments
 (0)