@@ -584,8 +584,6 @@ def step(
584584 s_noise : float = 1.0 ,
585585 generator : Optional [torch .Generator ] = None ,
586586 return_dict : bool = True ,
587- _model_output_uncond : Optional [torch .Tensor ] = None ,
588- _use_cfgpp : bool = False ,
589587 ) -> Union [EulerDiscreteSchedulerOutput , Tuple ]:
590588 """
591589 Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
@@ -629,11 +627,6 @@ def step(
629627 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
630628 "See `StableDiffusionPipeline` for a usage example."
631629 )
632-
633- if _use_cfgpp and self .config .prediction_type != "epsilon" :
634- raise ValueError (
635- f"CFG++ is only supported for prediction type `epsilon`, but got { self .config .prediction_type } ."
636- )
637630
638631 if self .step_index is None :
639632 self ._init_step_index (timestep )
@@ -675,38 +668,6 @@ def step(
675668 dt = self .sigmas [self .step_index + 1 ] - sigma_hat
676669
677670 prev_sample = sample + derivative * dt
678- if _use_cfgpp :
679- prev_sample = prev_sample + (_model_output_uncond - model_output ) * self .sigmas [self .step_index + 1 ]
680-
681- # denoised = sample - model_output * sigmas[i]
682- # d = (sample - denoised) / sigmas[i]
683- # new_sample = denoised + d * sigmas[i + 1]
684-
685- # new_sample = denoised + (sample - denoised) * sigmas[i + 1] / sigmas[i]
686- # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1]
687- # new_sample = sample + model_output * (sigmas[i + 1] - sigmas[i])
688- # new_sample = sample - model_output * sigmas[i] + model_output * sigmas[i + 1] --- (1)
689-
690- # CFG++ =====
691- # denoised = sample - model_output * sigmas[i]
692- # uncond_denoised = sample - model_output_uncond * sigmas[i]
693- # d = (sample - uncond_denoised) / sigmas[i]
694- # new_sample = denoised + d * sigmas[i + 1]
695-
696- # new_sample = denoised + (sample - uncond_denoised) * sigmas[i + 1] / sigmas[i]
697- # new_sample = sample - model_output * sigmas[i] + model_output_uncond * sigmas[i + 1] --- (2)
698-
699- # To go from (1) to (2):
700- # new_sample_2 = new_sample_1 - model_output * sigmas[i + 1] + model_output_uncond * sigmas[i + 1]
701- # new_sample_2 = new_sample_1 + (model_output_uncond - model_output) * sigmas[i + 1]
702- # new_sample_2 = new_sample_1 + diff * sigmas[i + 1]
703-
704- # diff = model_output_uncond - model_output
705- # diff = model_output_uncond - (model_output_uncond + g * (model_output_cond - model_output_uncond))
706- # diff = model_output_uncond - (g * model_output_cond + (1 - g) * model_output_uncond)
707- # diff = model_output_uncond - g * model_output_cond + (g - 1) * model_output_uncond
708- # diff = g * (model_output_uncond - model_output_cond)
709-
710671 # Cast sample back to model compatible dtype
711672 prev_sample = prev_sample .to (model_output .dtype )
712673
0 commit comments