@@ -493,6 +493,7 @@ def __init__(
493493 timesteps = 1000 ,
494494 use_ddim = False ,
495495 noise_schedule = 'sigmoid' ,
496+ objective = 'eps' ,
496497 schedule_kwargs : dict = dict (),
497498 time_difference = 0. ,
498499 train_prob_self_cond = 0.9
@@ -501,6 +502,9 @@ def __init__(
501502 self .model = model
502503 self .channels = self .model .channels
503504
505+ assert objective in {'x0' , 'eps' }, 'objective must be either predict x0 or noise'
506+ self .objective = objective
507+
504508 self .image_size = image_size
505509
506510 if noise_schedule == "linear" :
@@ -560,11 +564,7 @@ def ddpm_sample(self, shape, time_difference = None):
560564
561565 # get predicted x0
562566
563- x_start , last_latents = self .model (img , noise_cond , x_start , last_latents , return_latents = True )
564-
565- # clip x0
566-
567- x_start .clamp_ (- 1. , 1. )
567+ model_output , last_latents = self .model (img , noise_cond , x_start , last_latents , return_latents = True )
568568
569569 # get log(snr)
570570
@@ -577,6 +577,18 @@ def ddpm_sample(self, shape, time_difference = None):
577577 alpha , sigma = log_snr_to_alpha_sigma (log_snr )
578578 alpha_next , sigma_next = log_snr_to_alpha_sigma (log_snr_next )
579579
580+ # calculate x0 and noise
581+
582+ if self .objective == 'x0' :
583+ x_start = model_output
584+
585+ elif self .objective == 'eps' :
586+ x_start = (img - sigma * model_output ) / alpha
587+
588+ # clip x0
589+
590+ x_start .clamp_ (- 1. , 1. )
591+
580592 # derive posterior mean and variance
581593
582594 c = - expm1 (log_snr - log_snr_next )
@@ -628,15 +640,27 @@ def ddim_sample(self, shape, time_difference = None):
628640
629641 # predict x0
630642
631- x_start , last_latents = self .model (img , log_snr , x_start , last_latents , return_latents = True )
643+ model_output , last_latents = self .model (img , log_snr , x_start , last_latents , return_latents = True )
644+
645+ # calculate x0 and noise
646+
647+ if self .objective == 'x0' :
648+ x_start = model_output
649+
650+ elif self .objective == 'eps' :
651+ x_start = (img - sigma * model_output ) / alpha
632652
633653 # clip x0
634654
635655 x_start .clamp_ (- 1. , 1. )
636656
637657 # get predicted noise
638658
639- pred_noise = (img - alpha * x_start ) / sigma .clamp (min = 1e-8 )
659+ if self .objective == 'x0' :
660+ pred_noise = (img - alpha * x_start ) / sigma .clamp (min = 1e-8 )
661+
662+ elif self .objective == 'eps' :
663+ pred_noise = model_output
640664
641665 # calculate x next
642666
@@ -687,7 +711,13 @@ def forward(self, img, *args, **kwargs):
687711
688712 pred = self .model (noised_img , noise_level , self_cond , self_latents )
689713
690- return F .mse_loss (pred , img )
714+ if self .objective == 'x0' :
715+ target = img
716+
717+ elif self .objective == 'eps' :
718+ target = noise
719+
720+ return F .mse_loss (pred , target )
691721
692722# dataset classes
693723
0 commit comments