Skip to content

Commit 995f1c5

Browse files
committed
move offset noise to only during training, and make the strength overrideable
1 parent 6219bf7 commit 995f1c5

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

denoising_diffusion_pytorch/denoising_diffusion_pytorch.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -727,19 +727,24 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
727727
def q_sample(self, x_start, t, noise = None):
728728
noise = default(noise, lambda: torch.randn_like(x_start))
729729

730-
if self.offset_noise_strength > 0.:
731-
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
732-
noise += self.offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
733-
734730
return (
735731
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
736732
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
737733
)
738734

739-
def p_losses(self, x_start, t, noise = None):
735+
def p_losses(self, x_start, t, noise = None, offset_noise_strength = None):
740736
b, c, h, w = x_start.shape
737+
741738
noise = default(noise, lambda: torch.randn_like(x_start))
742739

740+
# offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
741+
742+
offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
743+
744+
if offset_noise_strength > 0.:
745+
offset_noise = torch.randn(x_start.shape[:2], device = self.device)
746+
noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')
747+
743748
# noise sample
744749

745750
x = self.q_sample(x_start = x_start, t = t, noise = noise)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.6.3'
1+
__version__ = '1.6.4'

0 commit comments

Comments
 (0)