Skip to content

Commit 8cc1db6

Browse files
committed
fix min snr logic, thanks to @justinlovelace
1 parent 77e4ced commit 8cc1db6

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

rin_pytorch/rin_pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def forward(self, img, *args, **kwargs):
836836
maybe_clipped_snr = snr.clone()
837837

838838
if self.min_snr_loss_weight:
839-
maybe_clipped_snr.clamp_(min = self.min_snr_gamma)
839+
maybe_clipped_snr.clamp_(max = self.min_snr_gamma)
840840

841841
if self.objective == 'eps':
842842
loss_weight = maybe_clipped_snr / snr

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.7.1',
6+
version = '0.7.2',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)