@@ -368,12 +368,13 @@ def forward_with_cond_scale(
         scaled_logits = null_logits + (logits - null_logits) * cond_scale
 
         if rescaled_phi == 0.:
-            return scaled_logits
+            return scaled_logits, null_logits
 
         std_fn = partial(torch.std, dim = tuple(range(1, scaled_logits.ndim)), keepdim = True)
         rescaled_logits = scaled_logits * (std_fn(logits) / std_fn(scaled_logits))
+        interpolated_rescaled_logits = rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)
 
-        return rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)
+        return interpolated_rescaled_logits, null_logits
 
     def forward(
         self,
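
For reference, a standalone sketch of what this hunk computes, with random tensors standing in for real model outputs (the shapes and names below are illustrative only, not from the library): the classifier-free guidance combination, followed by the rescaled-phi interpolation, with the unconditional prediction now returned alongside the result.

    import torch

    cond_scale, rescaled_phi = 6., 0.7

    logits      = torch.randn(2, 3, 64, 64)   # stand-in for the conditional prediction
    null_logits = torch.randn(2, 3, 64, 64)   # stand-in for the unconditional (null-class) prediction

    # classifier-free guidance: push the output away from the unconditional prediction
    scaled_logits = null_logits + (logits - null_logits) * cond_scale

    # rescale the guided output back toward the std of the conditional prediction,
    # then interpolate by rescaled_phi (mirrors the two new lines in the hunk above)
    dims = tuple(range(1, scaled_logits.ndim))
    rescaled_logits = scaled_logits * (logits.std(dim = dims, keepdim = True) / scaled_logits.std(dim = dims, keepdim = True))
    interpolated = rescaled_logits * rescaled_phi + scaled_logits * (1. - rescaled_phi)

    # after this commit, forward_with_cond_scale returns the unconditional
    # prediction as well, i.e. (interpolated, null_logits)
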
@@ -478,7 +479,8 @@ def __init__(
         ddim_sampling_eta = 1.,
         offset_noise_strength = 0.,
         min_snr_loss_weight = False,
-        min_snr_gamma = 5
+        min_snr_gamma = 5,
+        use_cfg_plus_plus = False # https://arxiv.org/pdf/2406.08070
     ):
         super().__init__()
         assert not (type(self) == GaussianDiffusion and model.channels != model.out_dim)
@@ -507,6 +509,10 @@ def __init__(
         timesteps, = betas.shape
         self.num_timesteps = int(timesteps)
 
+        # use cfg++ when ddim sampling
+
+        self.use_cfg_plus_plus = use_cfg_plus_plus
+
         # sampling related parameters
 
         self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training
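
A hypothetical usage sketch for the new flag. Everything other than use_cfg_plus_plus (the Unet arguments, image_size, timesteps, sampling_timesteps) follows the repo's usual classifier-free-guidance example and is an assumption here, not part of this diff.

    from denoising_diffusion_pytorch.classifier_free_guidance import Unet, GaussianDiffusion

    model = Unet(
        dim = 64,
        dim_mults = (1, 2, 4, 8),
        num_classes = 10,
        cond_drop_prob = 0.5
    )

    diffusion = GaussianDiffusion(
        model,
        image_size = 128,
        timesteps = 1000,
        sampling_timesteps = 250,   # fewer sampling steps than timesteps switches on DDIM sampling
        use_cfg_plus_plus = True    # the flag added in this commit
    )
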
@@ -604,24 +610,33 @@ def q_posterior(self, x_start, x_t, t):
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
     def model_predictions(self, x, t, classes, cond_scale = 6., rescaled_phi = 0.7, clip_x_start = False):
-        model_output = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescaled_phi = rescaled_phi)
+        model_output, model_output_null = self.model.forward_with_cond_scale(x, t, classes, cond_scale = cond_scale, rescaled_phi = rescaled_phi)
         maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
 
         if self.objective == 'pred_noise':
-            pred_noise = model_output
+            pred_noise = model_output if not self.use_cfg_plus_plus else model_output_null
+
             x_start = self.predict_start_from_noise(x, t, pred_noise)
             x_start = maybe_clip(x_start)
 
         elif self.objective == 'pred_x0':
             x_start = model_output
             x_start = maybe_clip(x_start)
-            pred_noise = self.predict_noise_from_start(x, t, x_start)
+            x_start_for_pred_noise = x_start if not self.use_cfg_plus_plus else maybe_clip(model_output_null)
+
+            pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)
 
         elif self.objective == 'pred_v':
             v = model_output
             x_start = self.predict_start_from_v(x, t, v)
             x_start = maybe_clip(x_start)
-            pred_noise = self.predict_noise_from_start(x, t, x_start)
+
+            x_start_for_pred_noise = x_start
+            if self.use_cfg_plus_plus:
+                x_start_for_pred_noise = self.predict_start_from_v(x, t, model_output_null)
+                x_start_for_pred_noise = maybe_clip(x_start_for_pred_noise)
+
+            pred_noise = self.predict_noise_from_start(x, t, x_start_for_pred_noise)
 
         return ModelPrediction(pred_noise, x_start)
 
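
For intuition (not part of the commit itself): the unconditional output is threaded through because CFG++ keeps the guided estimate of x_start but re-noises with the unconditional noise estimate when stepping to the next timestep. A rough sketch of a deterministic DDIM-style update under that rule, using dummy tensors and a dummy schedule value rather than the library's actual sampler:

    import torch

    x_start         = torch.randn(2, 3, 64, 64)   # guided x0 estimate kept by model_predictions
    pred_noise_null = torch.randn(2, 3, 64, 64)   # unconditional noise estimate returned under cfg++

    alpha_next = torch.tensor(0.95)                # dummy cumulative alpha for the next timestep

    # deterministic (eta = 0) DDIM-style step: denoise with the guided x0,
    # re-noise with the unconditional noise estimate
    x_next = x_start * alpha_next.sqrt() + pred_noise_null * (1. - alpha_next).sqrt()
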