Skip to content

Commit 7aa3914

Browse files
committed
make
1 parent 22f5460 commit 7aa3914

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

src/diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,21 @@ def retrieve_latents(
8787

8888
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
8989
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
90-
"""
91-
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
92-
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
90+
r"""
91+
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
92+
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
93+
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
94+
95+
Args:
96+
noise_cfg (`torch.Tensor`):
97+
The predicted noise tensor for the guided diffusion process.
98+
noise_pred_text (`torch.Tensor`):
99+
The predicted noise tensor for the text-guided diffusion process.
100+
guidance_rescale (`float`, *optional*, defaults to 0.0):
101+
A rescale factor applied to the noise predictions.
102+
103+
Returns:
104+
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
93105
"""
94106
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
95107
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -109,7 +121,7 @@ def retrieve_timesteps(
109121
sigmas: Optional[List[float]] = None,
110122
**kwargs,
111123
):
112-
"""
124+
r"""
113125
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
114126
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
115127
@@ -804,8 +816,6 @@ def prepare_mask_latents(
804816
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
805817
)
806818

807-
# star
808-
809819
# aligning device to prevent device errors when concating it with the latent model input
810820
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
811821
return mask, masked_image_latents

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,6 +1757,21 @@ def from_pretrained(cls, *args, **kwargs):
17571757
requires_backends(cls, ["torch", "transformers"])
17581758

17591759

1760+
class StableDiffusionPAGInpaintPipeline(metaclass=DummyObject):
1761+
_backends = ["torch", "transformers"]
1762+
1763+
def __init__(self, *args, **kwargs):
1764+
requires_backends(self, ["torch", "transformers"])
1765+
1766+
@classmethod
1767+
def from_config(cls, *args, **kwargs):
1768+
requires_backends(cls, ["torch", "transformers"])
1769+
1770+
@classmethod
1771+
def from_pretrained(cls, *args, **kwargs):
1772+
requires_backends(cls, ["torch", "transformers"])
1773+
1774+
17601775
class StableDiffusionPAGPipeline(metaclass=DummyObject):
17611776
_backends = ["torch", "transformers"]
17621777

0 commit comments

Comments
 (0)