1- # Implementation of StableDiffusionPAGPipeline
1+ # Implementation of StableDiffusionPipeline with PAG
2+ # https://ku-cvlab.github.io/Perturbed-Attention-Guidance
23
34import inspect
45from typing import Any , Callable , Dict , List , Optional , Union
@@ -134,8 +135,8 @@ def __call__(
134135
135136 value = attn .to_v (hidden_states_ptb )
136137
137- hidden_states_ptb = torch .zeros (value .shape ).to (value .get_device ())
138- # hidden_states_ptb = value
138+ # hidden_states_ptb = torch.zeros(value.shape).to(value.get_device())
139+ hidden_states_ptb = value
139140
140141 hidden_states_ptb = hidden_states_ptb .to (query .dtype )
141142
@@ -1045,7 +1046,7 @@ def pag_scale(self):
10451046 return self ._pag_scale
10461047
10471048 @property
1048- def do_adversarial_guidance (self ):
1049+ def do_perturbed_attention_guidance (self ):
10491050 return self ._pag_scale > 0
10501051
10511052 @property
@@ -1056,14 +1057,6 @@ def pag_adaptive_scaling(self):
10561057 def do_pag_adaptive_scaling (self ):
10571058 return self ._pag_adaptive_scaling > 0
10581059
1059- @property
1060- def pag_drop_rate (self ):
1061- return self ._pag_drop_rate
1062-
1063- @property
1064- def pag_applied_layers (self ):
1065- return self ._pag_applied_layers
1066-
10671060 @property
10681061 def pag_applied_layers_index (self ):
10691062 return self ._pag_applied_layers_index
@@ -1080,8 +1073,6 @@ def __call__(
10801073 guidance_scale : float = 7.5 ,
10811074 pag_scale : float = 0.0 ,
10821075 pag_adaptive_scaling : float = 0.0 ,
1083- pag_drop_rate : float = 0.5 ,
1084- pag_applied_layers : List [str ] = ["down" ], # ['down', 'mid', 'up']
10851076 pag_applied_layers_index : List [str ] = ["d4" ], # ['d4', 'd5', 'm0']
10861077 negative_prompt : Optional [Union [str , List [str ]]] = None ,
10871078 num_images_per_prompt : Optional [int ] = 1 ,
@@ -1221,8 +1212,6 @@ def __call__(
12211212
12221213 self ._pag_scale = pag_scale
12231214 self ._pag_adaptive_scaling = pag_adaptive_scaling
1224- self ._pag_drop_rate = pag_drop_rate
1225- self ._pag_applied_layers = pag_applied_layers
12261215 self ._pag_applied_layers_index = pag_applied_layers_index
12271216
12281217 # 2. Define call parameters
@@ -1257,13 +1246,13 @@ def __call__(
12571246 # to avoid doing two forward passes
12581247
12591248 # cfg
1260- if self .do_classifier_free_guidance and not self .do_adversarial_guidance :
1249+ if self .do_classifier_free_guidance and not self .do_perturbed_attention_guidance :
12611250 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ])
12621251 # pag
1263- elif not self .do_classifier_free_guidance and self .do_adversarial_guidance :
1252+ elif not self .do_classifier_free_guidance and self .do_perturbed_attention_guidance :
12641253 prompt_embeds = torch .cat ([prompt_embeds , prompt_embeds ])
12651254 # both
1266- elif self .do_classifier_free_guidance and self .do_adversarial_guidance :
1255+ elif self .do_classifier_free_guidance and self .do_perturbed_attention_guidance :
12671256 prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds , prompt_embeds ])
12681257
12691258 if ip_adapter_image is not None or ip_adapter_image_embeds is not None :
@@ -1306,7 +1295,7 @@ def __call__(
13061295 ).to (device = device , dtype = latents .dtype )
13071296
13081297 # 7. Denoising loop
1309- if self .do_adversarial_guidance :
1298+ if self .do_perturbed_attention_guidance :
13101299 down_layers = []
13111300 mid_layers = []
13121301 up_layers = []
@@ -1322,6 +1311,29 @@ def __call__(
13221311 else :
13231312 raise ValueError (f"Invalid layer type: { layer_type } " )
13241313
1314+ # when PAG is enabled, swap the PAG attention processor into the selected UNet layers
1315+ if self .do_perturbed_attention_guidance :
1316+ if self .do_classifier_free_guidance :
1317+ replace_processor = PAGCFGIdentitySelfAttnProcessor ()
1318+ else :
1319+ replace_processor = PAGIdentitySelfAttnProcessor ()
1320+
1321+ drop_layers = self .pag_applied_layers_index
1322+ for drop_layer in drop_layers :
1323+ try :
1324+ if drop_layer [0 ] == "d" :
1325+ down_layers [int (drop_layer [1 ])].processor = replace_processor
1326+ elif drop_layer [0 ] == "m" :
1327+ mid_layers [int (drop_layer [1 ])].processor = replace_processor
1328+ elif drop_layer [0 ] == "u" :
1329+ up_layers [int (drop_layer [1 ])].processor = replace_processor
1330+ else :
1331+ raise ValueError (f"Invalid layer type: { drop_layer [0 ]} " )
1332+ except IndexError :
1333+ raise ValueError (
1334+ f"Invalid layer index: { drop_layer } . Available layers: { len (down_layers )} down layers, { len (mid_layers )} mid layers, { len (up_layers )} up layers."
1335+ )
1336+
13251337 num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
13261338 self ._num_timesteps = len (timesteps )
13271339 with self .progress_bar (total = num_inference_steps ) as progress_bar :
@@ -1330,41 +1342,18 @@ def __call__(
13301342 continue
13311343
13321344 # cfg
1333- if self .do_classifier_free_guidance and not self .do_adversarial_guidance :
1345+ if self .do_classifier_free_guidance and not self .do_perturbed_attention_guidance :
13341346 latent_model_input = torch .cat ([latents ] * 2 )
13351347 # pag
1336- elif not self .do_classifier_free_guidance and self .do_adversarial_guidance :
1348+ elif not self .do_classifier_free_guidance and self .do_perturbed_attention_guidance :
13371349 latent_model_input = torch .cat ([latents ] * 2 )
13381350 # both
1339- elif self .do_classifier_free_guidance and self .do_adversarial_guidance :
1351+ elif self .do_classifier_free_guidance and self .do_perturbed_attention_guidance :
13401352 latent_model_input = torch .cat ([latents ] * 3 )
13411353 # neither cfg nor pag
13421354 else :
13431355 latent_model_input = latents
13441356
1345- # change attention layer in UNet if use PAG
1346- if self .do_adversarial_guidance :
1347- if self .do_classifier_free_guidance :
1348- replace_processor = PAGCFGIdentitySelfAttnProcessor ()
1349- else :
1350- replace_processor = PAGIdentitySelfAttnProcessor ()
1351-
1352- drop_layers = self .pag_applied_layers_index
1353- for drop_layer in drop_layers :
1354- try :
1355- if drop_layer [0 ] == "d" :
1356- down_layers [int (drop_layer [1 ])].processor = replace_processor
1357- elif drop_layer [0 ] == "m" :
1358- mid_layers [int (drop_layer [1 ])].processor = replace_processor
1359- elif drop_layer [0 ] == "u" :
1360- up_layers [int (drop_layer [1 ])].processor = replace_processor
1361- else :
1362- raise ValueError (f"Invalid layer type: { drop_layer [0 ]} " )
1363- except IndexError :
1364- raise ValueError (
1365- f"Invalid layer index: { drop_layer } . Available layers: { len (down_layers )} down layers, { len (mid_layers )} mid layers, { len (up_layers )} up layers."
1366- )
1367-
13681357 latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
13691358
13701359 # predict the noise residual
@@ -1381,14 +1370,14 @@ def __call__(
13811370 # perform guidance
13821371
13831372 # cfg
1384- if self .do_classifier_free_guidance and not self .do_adversarial_guidance :
1373+ if self .do_classifier_free_guidance and not self .do_perturbed_attention_guidance :
13851374 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
13861375
13871376 delta = noise_pred_text - noise_pred_uncond
13881377 noise_pred = noise_pred_uncond + self .guidance_scale * delta
13891378
13901379 # pag
1391- elif not self .do_classifier_free_guidance and self .do_adversarial_guidance :
1380+ elif not self .do_classifier_free_guidance and self .do_perturbed_attention_guidance :
13921381 noise_pred_original , noise_pred_perturb = noise_pred .chunk (2 )
13931382
13941383 signal_scale = self .pag_scale
@@ -1400,7 +1389,7 @@ def __call__(
14001389 noise_pred = noise_pred_original + signal_scale * (noise_pred_original - noise_pred_perturb )
14011390
14021391 # both
1403- elif self .do_classifier_free_guidance and self .do_adversarial_guidance :
1392+ elif self .do_classifier_free_guidance and self .do_perturbed_attention_guidance :
14041393 noise_pred_uncond , noise_pred_text , noise_pred_text_perturb = noise_pred .chunk (3 )
14051394
14061395 signal_scale = self .pag_scale
@@ -1458,11 +1447,8 @@ def __call__(
14581447 # Offload all models
14591448 self .maybe_free_model_hooks ()
14601449
1461- if not return_dict :
1462- return (image , has_nsfw_concept )
1463-
14641450 # change attention layer in UNet if use PAG
1465- if self .do_adversarial_guidance :
1451+ if self .do_perturbed_attention_guidance :
14661452 drop_layers = self .pag_applied_layers_index
14671453 for drop_layer in drop_layers :
14681454 try :
@@ -1479,4 +1465,7 @@ def __call__(
14791465 f"Invalid layer index: { drop_layer } . Available layers: { len (down_layers )} down layers, { len (mid_layers )} mid layers, { len (up_layers )} up layers."
14801466 )
14811467
1468+ if not return_dict :
1469+ return (image , has_nsfw_concept )
1470+
14821471 return StableDiffusionPipelineOutput (images = image , nsfw_content_detected = has_nsfw_concept )
0 commit comments