# See the License for the specific language governing permissions and
# limitations under the License.

- from typing import Any, Dict, List, Optional, Union
+ import math
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+ import torch

from ..configuration_utils import register_to_config
- from ..hooks import LayerSkipConfig
+ from ..hooks import HookRegistry, LayerSkipConfig
+ from ..hooks.layer_skip import _apply_layer_skip_hook
from ..utils import get_logger
- from .skip_layer_guidance import SkipLayerGuidance
+ from .guider_utils import BaseGuidance, rescale_noise_cfg
+
+
+ if TYPE_CHECKING:
+     from ..modular_pipelines.modular_pipeline import BlockState


logger = get_logger(__name__)  # pylint: disable=invalid-name


- class PerturbedAttentionGuidance(SkipLayerGuidance):
+ class PerturbedAttentionGuidance(BaseGuidance):
    """
    Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377

@@ -36,7 +44,7 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
    Additional reading:
    - [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)

-     PAG is implemented as a specialization of the SkipLayerGuidance due to similarities in the configuration parameters
+     PAG shares most of its implementation with SkipLayerGuidance due to the overlap in configuration parameters
    and implementation details.

    Args:
@@ -75,6 +83,8 @@ class PerturbedAttentionGuidance(SkipLayerGuidance):
    # complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
    # for each model architecture.

+     _input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+
    @register_to_config
    def __init__(
        self,
@@ -89,6 +99,15 @@ def __init__(
        start: float = 0.0,
        stop: float = 1.0,
    ):
+         super().__init__(start, stop)
+
+         self.guidance_scale = guidance_scale
+         self.skip_layer_guidance_scale = perturbed_guidance_scale
+         self.skip_layer_guidance_start = perturbed_guidance_start
+         self.skip_layer_guidance_stop = perturbed_guidance_stop
+         self.guidance_rescale = guidance_rescale
+         self.use_original_formulation = use_original_formulation
+
        if perturbed_guidance_config is None:
            if perturbed_guidance_layers is None:
                raise ValueError(
@@ -130,15 +149,123 @@ def __init__(
            config.skip_attention_scores = True
            config.skip_ff = False

-         super().__init__(
-             guidance_scale=guidance_scale,
-             skip_layer_guidance_scale=perturbed_guidance_scale,
-             skip_layer_guidance_start=perturbed_guidance_start,
-             skip_layer_guidance_stop=perturbed_guidance_stop,
-             skip_layer_guidance_layers=perturbed_guidance_layers,
-             skip_layer_config=perturbed_guidance_config,
-             guidance_rescale=guidance_rescale,
-             use_original_formulation=use_original_formulation,
-             start=start,
-             stop=stop,
-         )
+         self.skip_layer_config = perturbed_guidance_config
+         self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
+     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
+     def prepare_models(self, denoiser: torch.nn.Module) -> None:
+         self._count_prepared += 1
+         if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+             for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
+                 _apply_layer_skip_hook(denoiser, config, name=name)
+
+     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
+     def cleanup_models(self, denoiser: torch.nn.Module) -> None:
+         if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
+             registry = HookRegistry.check_if_exists_or_initialize(denoiser)
+             # Remove the hooks after inference
+             for hook_name in self._skip_layer_hook_names:
+                 registry.remove_hook(hook_name, recurse=True)
+
+     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
+     def prepare_inputs(
+         self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
+     ) -> List["BlockState"]:
+         if input_fields is None:
+             input_fields = self._input_fields
+
+         if self.num_conditions == 1:
+             tuple_indices = [0]
+             input_predictions = ["pred_cond"]
+         elif self.num_conditions == 2:
+             tuple_indices = [0, 1]
+             input_predictions = (
+                 ["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
+             )
+         else:
+             tuple_indices = [0, 1, 0]
+             input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
+         data_batches = []
+         for i in range(self.num_conditions):
+             data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
+             data_batches.append(data_batch)
+         return data_batches
+
+     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
+     def forward(
+         self,
+         pred_cond: torch.Tensor,
+         pred_uncond: Optional[torch.Tensor] = None,
+         pred_cond_skip: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+         pred = None
+
+         if not self._is_cfg_enabled() and not self._is_slg_enabled():
+             pred = pred_cond
+         elif not self._is_cfg_enabled():
+             shift = pred_cond - pred_cond_skip
+             pred = pred_cond if self.use_original_formulation else pred_cond_skip
+             pred = pred + self.skip_layer_guidance_scale * shift
+         elif not self._is_slg_enabled():
+             shift = pred_cond - pred_uncond
+             pred = pred_cond if self.use_original_formulation else pred_uncond
+             pred = pred + self.guidance_scale * shift
+         else:
+             shift = pred_cond - pred_uncond
+             shift_skip = pred_cond - pred_cond_skip
+             pred = pred_cond if self.use_original_formulation else pred_uncond
+             pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
+
+         if self.guidance_rescale > 0.0:
+             pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
+
+         return pred, {}
+
+     @property
+     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
+     def is_conditional(self) -> bool:
+         return self._count_prepared == 1 or self._count_prepared == 3
+
+     @property
+     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
+     def num_conditions(self) -> int:
+         num_conditions = 1
+         if self._is_cfg_enabled():
+             num_conditions += 1
+         if self._is_slg_enabled():
+             num_conditions += 1
+         return num_conditions
+
+     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
+     def _is_cfg_enabled(self) -> bool:
+         if not self._enabled:
+             return False
+
+         is_within_range = True
+         if self._num_inference_steps is not None:
+             skip_start_step = int(self._start * self._num_inference_steps)
+             skip_stop_step = int(self._stop * self._num_inference_steps)
+             is_within_range = skip_start_step <= self._step < skip_stop_step
+
+         is_close = False
+         if self.use_original_formulation:
+             is_close = math.isclose(self.guidance_scale, 0.0)
+         else:
+             is_close = math.isclose(self.guidance_scale, 1.0)
+
+         return is_within_range and not is_close
+
+     # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
+     def _is_slg_enabled(self) -> bool:
+         if not self._enabled:
+             return False
+
+         is_within_range = True
+         if self._num_inference_steps is not None:
+             skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
+             skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
+             is_within_range = skip_start_step < self._step < skip_stop_step
+
+         is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
+
+         return is_within_range and not is_zero
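
For orientation, here is a minimal usage sketch of the reworked guider. The parameter names come from the diff above; the concrete values and the `diffusers.guiders` import path are assumptions for illustration:

```python
# Hypothetical construction of the guider (values are illustrative assumptions).
from diffusers.guiders import PerturbedAttentionGuidance

guider = PerturbedAttentionGuidance(
    guidance_scale=7.5,             # classifier-free guidance strength
    perturbed_guidance_scale=3.0,   # strength of the perturbed-attention shift
    perturbed_guidance_layers=[8],  # transformer block(s) whose attention scores are skipped
    perturbed_guidance_start=0.01,  # fraction of inference steps before PAG activates
    perturbed_guidance_stop=0.2,    # fraction of inference steps after which PAG deactivates
)
```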
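When both CFG and the perturbed branch are active, `forward` reduces to a single update rule: start from `pred_uncond` (in the default formulation) and add the CFG shift plus the skip-layer shift. A self-contained numeric check of that rule, with illustrative tensor values:

```python
import torch

# Stand-ins for the three denoiser predictions gathered by prepare_inputs().
pred_cond = torch.tensor([1.0])       # conditional prediction
pred_uncond = torch.tensor([0.2])     # unconditional prediction
pred_cond_skip = torch.tensor([0.7])  # conditional prediction with attention scores skipped

guidance_scale = 5.0
skip_layer_guidance_scale = 3.0

# Default (use_original_formulation=False) branch of forward().
shift = pred_cond - pred_uncond
shift_skip = pred_cond - pred_cond_skip
pred = pred_uncond + guidance_scale * shift + skip_layer_guidance_scale * shift_skip
print(pred)  # tensor([5.1000]) = 0.2 + 5.0 * 0.8 + 3.0 * 0.3
```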
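One subtlety worth noting: `_is_cfg_enabled` gates its step window with an inclusive lower bound (`<=`), while `_is_slg_enabled` uses a strict one (`<`), so the perturbed branch never fires on the first step of its window. A small sketch of the gating arithmetic, with assumed schedule values:

```python
# Assumed schedule values for illustration only.
num_inference_steps = 50
skip_layer_guidance_start = 0.01  # int(0.01 * 50) == 0
skip_layer_guidance_stop = 0.2    # int(0.2 * 50) == 10

start_step = int(skip_layer_guidance_start * num_inference_steps)
stop_step = int(skip_layer_guidance_stop * num_inference_steps)

# Strict lower bound, as in _is_slg_enabled: active on steps 1..9 only.
slg_steps = [step for step in range(num_inference_steps) if start_step < step < stop_step]
print(slg_steps)  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```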