@@ -1471,6 +1471,14 @@ def denoising_value_valid(dnv):
14711471 generator ,
14721472 self .do_classifier_free_guidance ,
14731473 )
1474+ if self .do_perturbed_attention_guidance :
1475+ if self .do_classifier_free_guidance :
1476+ mask , _ = mask .chunk (2 )
1477+ masked_image_latents , _ = masked_image_latents .chunk (2 )
1478+ mask = self ._prepare_perturbed_attention_guidance (mask , mask , self .do_classifier_free_guidance )
1479+ masked_image_latents = self ._prepare_perturbed_attention_guidance (
1480+ masked_image_latents , masked_image_latents , self .do_classifier_free_guidance
1481+ )
14741482
14751483 # 8. Check that sizes of mask, masked image and latents match
14761484 if num_channels_unet == 9 :
@@ -1659,10 +1667,10 @@ def denoising_value_valid(dnv):
16591667
16601668 if num_channels_unet == 4 :
16611669 init_latents_proper = image_latents
1662- if self .do_classifier_free_guidance :
1663- init_mask , _ = mask .chunk (2 )
1670+ if self .do_perturbed_attention_guidance :
1671+ init_mask , * _ = mask . chunk ( 3 ) if self . do_classifier_free_guidance else mask .chunk (2 )
16641672 else :
1665- init_mask = mask
1673+ init_mask , * _ = mask . chunk ( 2 ) if self . do_classifier_free_guidance else mask
16661674
16671675 if i < len (timesteps ) - 1 :
16681676 noise_timestep = timesteps [i + 1 ]
0 commit comments