@@ -498,8 +498,8 @@ def sigmoid_schedule(t, start = -3, end = 3, tau = 1, clamp_min = 1e-9):
 
 # converting gamma to alpha, sigma or logsnr
 
-def gamma_to_alpha_sigma(gamma):
-    return torch.sqrt(gamma), torch.sqrt(1 - gamma)
+def gamma_to_alpha_sigma(gamma, scale = 1):
+    return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)
 
 def gamma_to_log_snr(gamma, eps = 1e-5):
     return -log(gamma ** -1. - 1, eps = eps)
@@ -543,7 +543,7 @@ def __init__(
         # the main finding presented in Ting Chen's paper - that higher resolution images require more noise for better training
 
         assert scale <= 1, 'scale must be less than or equal to 1'
-        self.scale = scale #
+        self.scale = scale
         self.normalize_img_variance = normalize_img_variance if scale < 1 else identity
 
         # gamma schedules
@@ -607,8 +607,8 @@ def ddpm_sample(self, shape, time_difference = None):
 
             # get alpha sigma of time and next time
 
-            alpha, sigma = gamma_to_alpha_sigma(gamma)
-            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next)
+            alpha, sigma = gamma_to_alpha_sigma(gamma, self.scale)
+            alpha_next, sigma_next = gamma_to_alpha_sigma(gamma_next, self.scale)
 
             # calculate x0 and noise
 
@@ -666,8 +666,8 @@ def ddim_sample(self, shape, time_difference = None):
 
             padded_gamma, padded_gamma_next = map(partial(right_pad_dims_to, img), (gamma, gamma_next))
 
-            alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
-            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next)
+            alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)
+            alpha_next, sigma_next = gamma_to_alpha_sigma(padded_gamma_next, self.scale)
 
             # add the time delay
 
@@ -728,10 +728,12 @@ def forward(self, img, *args, **kwargs):
 
         gamma = self.gamma_schedule(times)
         padded_gamma = right_pad_dims_to(img, gamma)
-        alpha, sigma = gamma_to_alpha_sigma(padded_gamma)
+        alpha, sigma = gamma_to_alpha_sigma(padded_gamma, self.scale)
 
         noised_img = alpha * img + sigma * noise
 
+        noised_img = self.normalize_img_variance(noised_img)
+
         # in the paper, they had to use a really high probability of latent self conditioning, up to 90% of the time
         # slight drawback
 
@@ -745,7 +747,6 @@ def forward(self, img, *args, **kwargs):
 
         # predict and take gradient step
 
-        noised_img = self.normalize_img_variance(noised_img)
         pred = self.model(noised_img, times, self_cond, self_latents)
 
         if self.objective == 'x0':
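
For context, here is a minimal, self-contained sketch (not the repository's exact code) of how the `scale` argument introduced in this commit enters the forward diffusion: `alpha` is attenuated by `scale`, so every noise level carries relatively more noise, matching Ting Chen's finding that higher-resolution images train better with more noise. The `normalize_variance` helper below is a hypothetical stand-in for the repository's `normalize_img_variance`.

```python
import torch

def gamma_to_alpha_sigma(gamma, scale = 1.):
    # attenuating only alpha (the signal coefficient) means every timestep
    # carries relatively more noise when scale < 1
    return torch.sqrt(gamma) * scale, torch.sqrt(1 - gamma)

def normalize_variance(x, eps = 1e-5):
    # hypothetical stand-in for normalize_img_variance: rescale each sample
    # in the batch back to unit variance after the scale attenuation
    std = x.std(dim = tuple(range(1, x.ndim)), keepdim = True)
    return x / std.clamp(min = eps)

imgs  = torch.randn(4, 3, 64, 64)       # stand-in batch of images
noise = torch.randn_like(imgs)
gamma = torch.full((4, 1, 1, 1), 0.7)   # noise level, already right-padded to image dims

alpha, sigma = gamma_to_alpha_sigma(gamma, scale = 0.5)
noised = alpha * imgs + sigma * noise   # forward diffusion q(x_t | x_0)
noised = normalize_variance(noised)     # as this commit now does before the model call
```

With `scale < 1` and roughly unit-variance inputs, the noised image has variance close to `scale**2 * gamma + (1 - gamma)`, which is below 1, so some form of variance restoration is applied before the image is passed to the model, as the commit does by moving the `normalize_img_variance` call directly after the noising step.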