@@ -727,19 +727,24 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):
727727 def q_sample (self , x_start , t , noise = None ):
728728 noise = default (noise , lambda : torch .randn_like (x_start ))
729729
730- if self .offset_noise_strength > 0. :
731- offset_noise = torch .randn (x_start .shape [:2 ], device = self .device )
732- noise += self .offset_noise_strength * rearrange (offset_noise , 'b c -> b c 1 1' )
733-
734730 return (
735731 extract (self .sqrt_alphas_cumprod , t , x_start .shape ) * x_start +
736732 extract (self .sqrt_one_minus_alphas_cumprod , t , x_start .shape ) * noise
737733 )
738734
739- def p_losses (self , x_start , t , noise = None ):
735+ def p_losses (self , x_start , t , noise = None , offset_noise_strength = None ):
740736 b , c , h , w = x_start .shape
737+
741738 noise = default (noise , lambda : torch .randn_like (x_start ))
742739
740+ # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
741+
742+ offset_noise_strength = default (offset_noise_strength , self .offset_noise_strength )
743+
744+ if offset_noise_strength > 0. :
745+ offset_noise = torch .randn (x_start .shape [:2 ], device = self .device )
746+ noise += offset_noise_strength * rearrange (offset_noise , 'b c -> b c 1 1' )
747+
743748 # noise sample
744749
745750 x = self .q_sample (x_start = x_start , t = t , noise = noise )
0 commit comments