@@ -568,6 +568,8 @@ def __init__(
568568 objective = 'v' ,
569569 schedule_kwargs : dict = dict (),
570570 time_difference = 0. ,
571+ min_snr_loss_weight = True ,
572+ min_snr_gamma = 5 ,
571573 train_prob_self_cond = 0.9 ,
572574 scale = 1. # this will be set to < 1. for better convergence when training on higher resolution images
573575 ):
@@ -611,6 +613,11 @@ def __init__(
611613
612614 self .train_prob_self_cond = train_prob_self_cond
613615
616+ # min snr loss weight
617+
618+ self .min_snr_loss_weight = min_snr_loss_weight
619+ self .min_snr_gamma = min_snr_gamma
620+
614621 @property
615622 def device (self ):
616623 return next (self .model .parameters ()).device
@@ -811,16 +818,36 @@ def forward(self, img, *args, **kwargs):
811818
812819 pred = self .model (noised_img , times , self_cond , self_latents )
813820
814- if self .objective == 'x0' :
815- target = img
816-
817- elif self .objective == 'eps' :
821+ if self .objective == 'eps' :
818822 target = noise
819823
824+ elif self .objective == 'x0' :
825+ target = img
826+
820827 elif self .objective == 'v' :
821828 target = alpha * noise - sigma * img
822829
823- return F .mse_loss (pred , target )
830+ loss = F .mse_loss (pred , target , reduction = 'none' )
831+ loss = reduce (loss , 'b ... -> b' , 'mean' )
832+
833+ # min snr loss weight
834+
835+ snr = (alpha * alpha ) / (sigma * sigma )
836+ maybe_clipped_snr = snr .clone ()
837+
838+ if self .min_snr_loss_weight :
839+ maybe_clipped_snr .clamp_ (max = self .min_snr_gamma )
840+
841+ if self .objective == 'eps' :
842+ loss_weight = maybe_clipped_snr / snr
843+
844+ elif self .objective == 'x0' :
845+ loss_weight = maybe_clipped_snr
846+
847+ elif self .objective == 'v' :
848+ loss_weight = maybe_clipped_snr / (snr + 1 )
849+
850+ return (loss * loss_weight ).mean ()
824851
825852# dataset classes
826853
@@ -872,7 +899,7 @@ def __init__(
872899 train_num_steps = 100000 ,
873900 ema_update_every = 10 ,
874901 ema_decay = 0.995 ,
875- adam_betas = (0.9 , 0.99 ),
902+ betas = (0.9 , 0.99 ),
876903 save_and_sample_every = 1000 ,
877904 num_samples = 25 ,
878905 results_folder = './results' ,
@@ -912,7 +939,7 @@ def __init__(
912939
913940 # optimizer
914941
915- self .opt = Adam (diffusion_model .parameters (), lr = train_lr , betas = adam_betas )
942+ self .opt = Adam (diffusion_model .parameters (), lr = train_lr , betas = betas )
916943
917944 # for logging results in a folder periodically
918945
0 commit comments