@@ -71,6 +71,8 @@ class SkipLayerGuidance(GuidanceMixin):
7171 [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
7272 """
7373
# Order in which the denoiser's outputs are handed to prepare_outputs:
# fully-conditional, unconditional, then conditional-with-skipped-layers.
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
7476 def __init__ (
7577 self ,
7678 guidance_scale : float = 7.5 ,
@@ -82,6 +84,8 @@ def __init__(
8284 guidance_rescale : float = 0.0 ,
8385 use_original_formulation : bool = False ,
8486 ):
87+ super ().__init__ ()
88+
8589 self .guidance_scale = guidance_scale
8690 self .skip_layer_guidance_scale = skip_layer_guidance_scale
8791 self .skip_layer_guidance_start = skip_layer_guidance_start
@@ -157,6 +161,18 @@ def prepare_inputs(self, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]])
157161 )
158162 return tuple (list_of_inputs )
159163
def prepare_outputs(self, pred: torch.Tensor) -> None:
    """Store one denoiser prediction under the guidance key it corresponds to.

    Called once per denoiser forward pass; the call index selects the key
    from ``_input_predictions``. Raises ``ValueError`` if called more than
    ``num_conditions`` times for the current step.
    """
    self._num_outputs_prepared += 1
    if self._num_outputs_prepared > self.num_conditions:
        raise ValueError(f"Expected {self.num_conditions} outputs, but prepare_outputs called more times.")
    position = self._num_outputs_prepared - 1
    key = self._input_predictions[position]
    # With SLG active but CFG off, only pred_cond and pred_cond_skip are
    # produced; the second output must land in "pred_cond_skip" rather than
    # the unused "pred_uncond" slot.
    slg_without_cfg = self._is_slg_enabled() and not self._is_cfg_enabled()
    if slg_without_cfg and position == 1:
        key = "pred_cond_skip"
    self._preds[key] = pred
175+
160176 def cleanup_models (self , denoiser : torch .nn .Module ):
161177 registry = HookRegistry .check_if_exists_or_initialize (denoiser )
162178 # Remove the hooks after inference
@@ -173,16 +189,16 @@ def forward(
173189 skip_start_step = int (self .skip_layer_guidance_start * self ._num_inference_steps )
174190 skip_stop_step = int (self .skip_layer_guidance_stop * self ._num_inference_steps )
175191
176- if math . isclose ( self .guidance_scale , 1.0 ) and math . isclose ( self .skip_layer_guidance_scale , 1.0 ):
192+ if not self ._is_cfg_enabled ( ) and not self ._is_slg_enabled ( ):
177193 pred = pred_cond
178- elif math . isclose ( self .guidance_scale , 1.0 ):
194+ elif not self ._is_cfg_enabled ( ):
179195 if skip_start_step < self ._step < skip_stop_step :
180196 shift = pred_cond - pred_cond_skip
181197 pred = pred_cond if self .use_original_formulation else pred_cond_skip
182198 pred = pred + self .skip_layer_guidance_scale * shift
183199 else :
184200 pred = pred_cond
185- elif math . isclose ( self .skip_layer_guidance_scale , 1.0 ):
201+ elif not self ._is_slg_enabled ( ):
186202 shift = pred_cond - pred_uncond
187203 pred = pred_cond if self .use_original_formulation else pred_uncond
188204 pred = pred + self .guidance_scale * shift
@@ -203,12 +219,19 @@ def forward(
@property
def num_conditions(self) -> int:
    """Number of denoiser forward passes required at the current step (1-3).

    One base conditional pass, plus one extra pass for each guidance
    branch (classifier-free, skip-layer) that is currently enabled.
    """
    enabled_branches = (self._is_cfg_enabled(), self._is_slg_enabled())
    return 1 + sum(enabled_branches)
227+
228+ def _is_cfg_enabled (self ) -> bool :
229+ if self .use_original_formulation :
230+ return not math .isclose (self .guidance_scale , 0.0 )
231+ else :
232+ return not math .isclose (self .guidance_scale , 1.0 )
233+
234+ def _is_slg_enabled (self ) -> bool :
235+ skip_start_step = int (self .skip_layer_guidance_start * self ._num_inference_steps )
236+ skip_stop_step = int (self .skip_layer_guidance_stop * self ._num_inference_steps )
237+ return skip_start_step < self ._step < skip_stop_step
0 commit comments