Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/diffusers/pipelines/pag/pag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui
t (int): The current time step.

Returns:
torch.Tensor: The updated noise prediction tensor after applying perturbed attention guidance.
Tuple[torch.Tensor, torch.Tensor]: The updated noise prediction tensor after applying perturbed attention
guidance and the text noise prediction.
"""
pag_scale = self._get_pag_scale(t)
if do_classifier_free_guidance:
Expand All @@ -122,7 +123,7 @@ def _apply_perturbed_attention_guidance(self, noise_pred, do_classifier_free_gui
else:
noise_pred_text, noise_pred_perturb = noise_pred.chunk(2)
noise_pred = noise_pred_text + pag_scale * (noise_pred_text - noise_pred_perturb)
return noise_pred
return noise_pred, noise_pred_text
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like we broke the tests; I looked up the import of PAGMixin on github, not many but some so let's try to not making a breaking change here

we can add a argument return_pred_text; default to False but set True when calling from our pipelines.


def _prepare_perturbed_attention_guidance(self, cond, uncond, do_classifier_free_guidance):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1598,7 +1598,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_kolors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,7 +1061,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, do_classifier_free_guidance, guidance_scale, current_timestep
)
elif do_classifier_free_guidance:
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_sd_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_sd_xl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,7 @@ def __call__(

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,7 +1437,7 @@ def denoising_value_valid(dnv):

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ def denoising_value_valid(dnv):

# perform guidance
if self.do_perturbed_attention_guidance:
noise_pred = self._apply_perturbed_attention_guidance(
noise_pred, noise_pred_text = self._apply_perturbed_attention_guidance(
noise_pred, self.do_classifier_free_guidance, self.guidance_scale, t
)
elif self.do_classifier_free_guidance:
Expand Down
Loading