Skip to content

Commit b8f8e40

Browse files
committed
fix min snr logic, thanks to @justinlovelace
1 parent e6f2d01 commit b8f8e40

File tree

5 files changed

+5
-5
lines changed

5 files changed

+5
-5
lines changed

denoising_diffusion_pytorch/classifier_free_guidance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def __init__(
562562

563563
maybe_clipped_snr = snr.clone()
564564
if min_snr_loss_weight:
565-
maybe_clipped_snr.clamp_(min = min_snr_gamma)
565+
maybe_clipped_snr.clamp_(max = min_snr_gamma)
566566

567567
if objective == 'pred_noise':
568568
loss_weight = maybe_clipped_snr / snr

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def __init__(
538538

539539
maybe_clipped_snr = snr.clone()
540540
if min_snr_loss_weight:
541-
maybe_clipped_snr.clamp_(min = min_snr_gamma)
541+
maybe_clipped_snr.clamp_(max = min_snr_gamma)
542542

543543
if objective == 'pred_noise':
544544
register_buffer('loss_weight', maybe_clipped_snr / snr)

denoising_diffusion_pytorch/guided_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def __init__(
532532

533533
maybe_clipped_snr = snr.clone()
534534
if min_snr_loss_weight:
535-
maybe_clipped_snr.clamp_(min = min_snr_gamma)
535+
maybe_clipped_snr.clamp_(max = min_snr_gamma)
536536

537537
if objective == 'pred_noise':
538538
loss_weight = maybe_clipped_snr / snr

denoising_diffusion_pytorch/simple_diffusion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,7 @@ def p_losses(self, x_start, times, noise = None):
690690

691691
maybe_clip_snr = snr.clone()
692692
if self.min_snr_loss_weight:
693-
maybe_clip_snr.clamp_(min = self.min_snr_gamma)
693+
maybe_clip_snr.clamp_(max = self.min_snr_gamma)
694694

695695
if self.pred_objective == 'v':
696696
loss_weight = maybe_clip_snr / (snr + 1)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.5.2'
1+
__version__ = '1.5.3'

0 commit comments

Comments
 (0)