# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Any, Dict, List, Optional, Union
+import math
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+import torch

from ..configuration_utils import register_to_config
-from ..hooks import LayerSkipConfig
+from ..hooks import HookRegistry, LayerSkipConfig
+from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
-from .skip_layer_guidance import SkipLayerGuidance
+from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+if TYPE_CHECKING:
+    from ..modular_pipelines.modular_pipeline import BlockState


logger = get_logger(__name__)  # pylint: disable=invalid-name


-class PerturbedAttentionGuidance(SkipLayerGuidance):
+class PerturbedAttentionGuidance(BaseGuidance):
    """
    Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377

@@ -36,7 +44,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

-    PAG is implemented as a specialization of the SkipLayerGuidance due to similarities in the configuration parameters
+    PAG is implemented similarly to SkipLayerGuidance due to overlap in the configuration parameters
    and implementation details.

    Args:
@@ -75,6 +83,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
    # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
    # for each model architecture.

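+    # The three prediction streams this guider may request from the denoiser: the conditional pass, the
+    # unconditional pass, and the conditional pass with the attention-skip hooks applied (the perturbed pass).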
+    _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+
    @register_to_config
    def __init__(
        self,
@@ -89,6 +99,15 @@ def __init__(
        start: float = 0.0,
        stop: float = 1.0,
    ):
+        super().__init__(start, stop)
+
+        self.guidance_scale = guidance_scale
+        self.skip_layer_guidance_scale = perturbed_guidance_scale
+        self.skip_layer_guidance_start = perturbed_guidance_start
+        self.skip_layer_guidance_stop = perturbed_guidance_stop
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+
        if perturbed_guidance_config is None:
            if perturbed_guidance_layers is None:
                raise ValueError(
@@ -130,15 +149,123 @@ def __init__(
            config.skip_attention_scores = True
            config.skip_ff = False

-        super().__init__(
-            guidance_scale=guidance_scale,
-            skip_layer_guidance_scale=perturbed_guidance_scale,
-            skip_layer_guidance_start=perturbed_guidance_start,
-            skip_layer_guidance_stop=perturbed_guidance_stop,
-            skip_layer_guidance_layers=perturbed_guidance_layers,
-            skip_layer_config=perturbed_guidance_config,
-            guidance_rescale=guidance_rescale,
-            use_original_formulation=use_original_formulation,
-            start=start,
-            stop=stop,
-        )
+        self.skip_layer_config = perturbed_guidance_config
+        self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
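+        # One uniquely named hook per LayerSkipConfig, so each hook can be attached and removed independently.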
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
+    def prepare_models(self, denoiser: torch.nn.Module) -> None:
+        self._count_prepared += 1
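+        # Attach the layer-skip hooks only when preparing a conditional batch after the first one, i.e. the
+        # perturbed pass; the plain conditional and unconditional passes run the unmodified denoiser.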
+        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+            for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
+                _apply_layer_skip_hook(denoiser, config, name=name)
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
+    def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+        if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+            registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+            # Remove the hooks after inference
+            for hook_name in self._skip_layer_hook_names:
+                registry.remove_hook(hook_name, recurse=True)
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
+    def prepare_inputs(
+        self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+    ) -> List["BlockState"]:
+        if input_fields is None:
+            input_fields = self._input_fields
+
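+        # tuple_indices selects which element of each (cond, uncond) input field a batch reads: index 0 is the
+        # conditional input (reused for the perturbed pass), index 1 the unconditional one.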
+        if self.num_conditions == 1:
+            tuple_indices = [0]
+            input_predictions = ["pred_cond"]
+        elif self.num_conditions == 2:
+            tuple_indices = [0, 1]
+            input_predictions = (
+                ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+            )
+        else:
+            tuple_indices = [0, 1, 0]
+            input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+        data_batches = []
+        for i in range(self.num_conditions):
+            data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+            data_batches.append(data_batch)
+        return data_batches
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
+    def forward(
+        self,
+        pred_cond: torch.Tensor,
+        pred_uncond: Optional[torch.Tensor] = None,
+        pred_cond_skip: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        pred = None
+
+        if not self._is_cfg_enabled() and not self._is_slg_enabled():
+            pred = pred_cond
+        elif not self._is_cfg_enabled():
+            shift = pred_cond - pred_cond_skip
+            pred = pred_cond if self.use_original_formulation else pred_cond_skip
+            pred = pred + self.skip_layer_guidance_scale * shift
+        elif not self._is_slg_enabled():
+            shift = pred_cond - pred_uncond
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift
+        else:
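+            # Both CFG and the perturbed guidance are active; with the default formulation this evaluates to
+            #   pred = pred_uncond + w_cfg * (pred_cond - pred_uncond) + w_pag * (pred_cond - pred_cond_skip)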
+            shift = pred_cond - pred_uncond
+            shift_skip = pred_cond - pred_cond_skip
+            pred = pred_cond if self.use_original_formulation else pred_uncond
+            pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
+
+        if self.guidance_rescale > 0.0:
+            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+        return pred, {}
+
+    @property
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
+    def is_conditional(self) -> bool:
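+        # With all three predictions active, batches are prepared in the order cond, uncond, cond_skip, so
+        # counts 1 and 3 correspond to conditional passes.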
+        return self._count_prepared == 1 or self._count_prepared == 3
+
+    @property
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
+    def num_conditions(self) -> int:
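+        # One forward pass is always needed for pred_cond; CFG and the skip-layer pass each add one more.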
+        num_conditions = 1
+        if self._is_cfg_enabled():
+            num_conditions += 1
+        if self._is_slg_enabled():
+            num_conditions += 1
+        return num_conditions
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
+    def _is_cfg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self._start * self._num_inference_steps)
+            skip_stop_step = int(self._stop * self._num_inference_steps)
+            is_within_range = skip_start_step <= self._step < skip_stop_step
+
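+        # CFG is a no-op at scale 0.0 under the original formulation (pred_cond + w * shift) and at scale 1.0
+        # under the standard formulation (pred_uncond + w * shift).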
+        is_close = False
+        if self.use_original_formulation:
+            is_close = math.isclose(self.guidance_scale, 0.0)
+        else:
+            is_close = math.isclose(self.guidance_scale, 1.0)
+
+        return is_within_range and not is_close
+
+    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
+    def _is_slg_enabled(self) -> bool:
+        if not self._enabled:
+            return False
+
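+        # Note the strict lower bound in the range check below: unlike the CFG check, the perturbed pass
+        # remains disabled at the exact start step.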
+        is_within_range = True
+        if self._num_inference_steps is not None:
+            skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
+            skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
+            is_within_range = skip_start_step < self._step < skip_stop_step
+
+        is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
+        return is_within_range and not is_zero