@@ -87,9 +87,21 @@ def retrieve_latents(
8787
8888# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
8989def rescale_noise_cfg (noise_cfg , noise_pred_text , guidance_rescale = 0.0 ):
90- """
91- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
92- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
90+ r"""
91+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
92+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
93+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
94+
95+ Args:
96+ noise_cfg (`torch.Tensor`):
97+ The predicted noise tensor for the guided diffusion process.
98+ noise_pred_text (`torch.Tensor`):
99+ The predicted noise tensor for the text-guided diffusion process.
100+ guidance_rescale (`float`, *optional*, defaults to 0.0):
101+ A rescale factor applied to the noise predictions.
102+
103+ Returns:
104+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
93105 """
94106 std_text = noise_pred_text .std (dim = list (range (1 , noise_pred_text .ndim )), keepdim = True )
95107 std_cfg = noise_cfg .std (dim = list (range (1 , noise_cfg .ndim )), keepdim = True )
@@ -109,7 +121,7 @@ def retrieve_timesteps(
109121 sigmas : Optional [List [float ]] = None ,
110122 ** kwargs ,
111123):
112- """
124+ r"""
113125 Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
114126 custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
115127
@@ -804,8 +816,6 @@ def prepare_mask_latents(
804816 torch .cat ([masked_image_latents ] * 2 ) if do_classifier_free_guidance else masked_image_latents
805817 )
806818
807- # star
808-
809819 # aligning device to prevent device errors when concating it with the latent model input
810820 masked_image_latents = masked_image_latents .to (device = device , dtype = dtype )
811821 return mask , masked_image_latents
0 commit comments