1313# limitations under the License.
1414
1515import math
16- from typing import List , Optional , Union , TYPE_CHECKING
16+ from typing import TYPE_CHECKING , List , Optional , Union
1717
1818import torch
1919
2020from ..hooks import HookRegistry , LayerSkipConfig
2121from ..hooks .layer_skip import _apply_layer_skip_hook
2222from .guider_utils import BaseGuidance , rescale_noise_cfg
2323
24+
2425if TYPE_CHECKING :
2526 from ..pipelines .modular_pipeline import BlockState
2627
@@ -113,13 +114,13 @@ def prepare_models(self, denoiser: torch.nn.Module) -> None:
113114 if self ._is_ag_enabled () and self .is_unconditional :
114115 for name , config in zip (self ._auto_guidance_hook_names , self .auto_guidance_config ):
115116 _apply_layer_skip_hook (denoiser , config , name = name )
116-
117+
117118 def cleanup_models (self , denoiser : torch .nn .Module ) -> None :
118119 if self ._is_ag_enabled () and self .is_unconditional :
119120 for name in self ._auto_guidance_hook_names :
120121 registry = HookRegistry .check_if_exists_or_initialize (denoiser )
121122 registry .remove_hook (name , recurse = True )
122-
123+
123124 def prepare_inputs (self , data : "BlockState" ) -> List ["BlockState" ]:
124125 tuple_indices = [0 ] if self .num_conditions == 1 else [0 , 1 ]
125126 data_batches = []
@@ -140,9 +141,9 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
140141
141142 if self .guidance_rescale > 0.0 :
142143 pred = rescale_noise_cfg (pred , pred_cond , self .guidance_rescale )
143-
144+
144145 return pred , {}
145-
146+
146147 @property
147148 def is_conditional (self ) -> bool :
148149 return self ._count_prepared == 1
@@ -157,17 +158,17 @@ def num_conditions(self) -> int:
157158 def _is_ag_enabled (self ) -> bool :
158159 if not self ._enabled :
159160 return False
160-
161+
161162 is_within_range = True
162163 if self ._num_inference_steps is not None :
163164 skip_start_step = int (self ._start * self ._num_inference_steps )
164165 skip_stop_step = int (self ._stop * self ._num_inference_steps )
165166 is_within_range = skip_start_step <= self ._step < skip_stop_step
166-
167+
167168 is_close = False
168169 if self .use_original_formulation :
169170 is_close = math .isclose (self .guidance_scale , 0.0 )
170171 else :
171172 is_close = math .isclose (self .guidance_scale , 1.0 )
172-
173+
173174 return is_within_range and not is_close
0 commit comments