@@ -109,26 +109,26 @@ def prepare_outputs(self, pred: torch.Tensor) -> None:
             key = "pred_perturbed"
         self._preds[key] = pred

-        # Prepare denoiser for perturbed attention prediction if needed
-        if not self._is_pag_enabled():
-            return
-        should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
-            not self._is_cfg_enabled() and self._num_outputs_prepared == 1
-        )
-        if should_register_pag:
-            self._is_pag_batch = True
-            self._original_processors = _replace_attention_processors(
-                self._denoiser,
-                self.pag_applied_layers,
-                skip_context_attention=self.skip_context_attention,
-                metadata_name="perturbed_attention_guidance_processor_cls",
-            )
-        elif self._is_pag_batch:
-            # Restore the original attention processors
+        # Restore the original attention processors if previously replaced
+        if self._is_pag_batch:
             _replace_attention_processors(self._denoiser, processors=self._original_processors)
             self._is_pag_batch = False
             self._original_processors = None

+        # Prepare denoiser for perturbed attention prediction if needed
+        if self._is_pag_enabled():
+            should_register_pag = (self._is_cfg_enabled() and self._num_outputs_prepared == 2) or (
+                not self._is_cfg_enabled() and self._num_outputs_prepared == 1
+            )
+            if should_register_pag:
+                self._is_pag_batch = True
+                self._original_processors = _replace_attention_processors(
+                    self._denoiser,
+                    self.pag_applied_layers,
+                    skip_context_attention=self.skip_context_attention,
+                    metadata_name="perturbed_attention_guidance_processor_cls",
+                )
+
     def cleanup_models(self, denoiser: torch.nn.Module):
         self._denoiser = None
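For context, a minimal sketch of the register/restore pattern this hunk reorders. `DummyDenoiser` and `swap_processors` are hypothetical stand-ins for the real denoiser and `_replace_attention_processors`; the point is only that registering returns the original processors so the caller can restore them later. Running the restore check first on every call (instead of in an `elif`) guarantees the originals are back in place before any non-perturbed prediction and before a new registration.

```python
# Hypothetical sketch -- not the library's code. DummyDenoiser and
# swap_processors stand in for the denoiser and _replace_attention_processors.


class DummyDenoiser:
    """Stand-in denoiser holding a per-layer attention-processor mapping."""

    def __init__(self):
        self.processors = {"blocks.0.attn": "default", "blocks.1.attn": "default"}


def swap_processors(denoiser, layers=None, processors=None):
    """Register perturbed processors on `layers`, or restore an explicit mapping.

    Returns the previous mapping when registering, so the caller can restore it.
    """
    if processors is not None:  # restore path: reinstall the saved originals
        denoiser.processors = dict(processors)
        return None
    originals = dict(denoiser.processors)  # register path: save, then replace
    for name in layers:
        denoiser.processors[name] = "perturbed"
    return originals


denoiser = DummyDenoiser()

# Register for the perturbed-attention batch (mirrors should_register_pag=True).
originals = swap_processors(denoiser, layers=["blocks.0.attn"])
assert denoiser.processors["blocks.0.attn"] == "perturbed"

# ... the perturbed forward pass would run here ...

# Restore before the next prediction (mirrors the _is_pag_batch branch).
swap_processors(denoiser, processors=originals)
assert denoiser.processors == {"blocks.0.attn": "default", "blocks.1.attn": "default"}
```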