diff --git a/src/diffusers/pipelines/pag/pag_utils.py b/src/diffusers/pipelines/pag/pag_utils.py index 728f730c9904..7a6e30a3c6be 100644 --- a/src/diffusers/pipelines/pag/pag_utils.py +++ b/src/diffusers/pipelines/pag/pag_utils.py @@ -98,7 +98,9 @@ def _get_pag_scale(self, t): else: return self.pag_scale - def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_guidance, guidance_scale, t): + def _apply_perturbed_attention_guidance( + self, noise_pred, do_classifier_free_guidance, guidance_scale, t, return_pred_text=False + ): r""" Apply perturbed attention guidance to the noise prediction. @@ -107,9 +109,11 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui do_classifier_free_guidance (bool): Whether to apply classifier-free guidance. guidance_scale (float): The scale factor for the guidance term. t (int): The current time step. + return_pred_text (bool): Whether to return the text noise prediction. Returns: - torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance. + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: The updated noise prediction tensor after applying + perturbed attention guidance and the text noise prediction. """ pag_scale = self._get_pag_scale(t) if do_classifier_free_guidance: @@ -122,6 +126,8 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui else: noise_pred_text, noise_pred_perturb = noise_pred.chunk(2) noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb) + if return_pred_text: + return noise_pred, noise_pred_text return noise_pred def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance): diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index 63126cc5aae9..4663db3a15a1 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -893,8 +893,8 @@ def __call__( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd.py b/src/diffusers/pipelines/pag/pipeline_pag_sd.py index c6a4f7f42c84..e9742b08af50 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd.py @@ -993,8 +993,8 @@ def __call__( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py index 18fc06c1f9b8..8da4349594b4 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py @@ -1237,8 +1237,8 @@ def __call__( # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py index dc85aaaca37f..4c2c4e5aa3fa 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py @@ -1437,8 +1437,8 @@ def denoising_value_valid(dnv): # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py index f5ebf4300934..49e4c5ffd50c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py @@ -1649,8 +1649,8 @@ def denoising_value_valid(dnv): # perform guidance if self.do_perturbed_attention_guidance: - noise_pred = self._apply_perturbed_attention_guidance( - noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t + noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance( + noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t, True ) elif self.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)