@@ -100,6 +100,48 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
100100 return noise_cfg
101101
102102
103+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.MomentumBuffer
class MomentumBuffer:
    """Running-average buffer for adaptive projected guidance (APG).

    Accumulates an exponential-moving-average-style history of guidance
    difference tensors, weighted by ``momentum``. A negative ``momentum``
    flips the sign of the accumulated history, as used in
    [Eliminating Oversaturation and Artifacts of High Guidance Scales in
    Diffusion Models](https://arxiv.org/pdf/2410.02416).
    """

    def __init__(self, momentum: float):
        self.momentum = momentum
        # Starts as the scalar 0 so the first `update` works via broadcasting;
        # becomes a tensor after the first call to `update`.
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        """Fold ``update_value`` into the running average in place."""
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
112+
113+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.normalized_guidance
def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional["MomentumBuffer"] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
):
    """Apply adaptive projected guidance (APG) to a pair of noise predictions.

    Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
    in Diffusion Models](https://arxiv.org/pdf/2410.02416).

    Args:
        pred_cond: Conditional model prediction.
        pred_uncond: Unconditional model prediction, same shape as `pred_cond`.
        guidance_scale: Classifier-free guidance scale; `1.0` returns `pred_cond` unchanged.
        momentum_buffer: Optional `MomentumBuffer`; when provided, the guidance
            difference is smoothed with the buffer's running average across steps.
        eta: Weight of the component of the guidance update parallel to `pred_cond`.
            `eta=1.0` keeps the full update; smaller values suppress the parallel
            (oversaturating) component.
        norm_threshold: If > 0, rescale the guidance difference so its L2 norm over
            the last three dims does not exceed this value.

    Returns:
        The guided prediction, same shape and dtype as `pred_cond`.
    """
    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        # Saturation control: clamp the magnitude of the guidance update.
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    # Project `diff` onto the direction of `pred_cond` in double precision for
    # numerical stability, then down-weight the parallel component by `eta`.
    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
143+
144+
103145# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104146def retrieve_timesteps (
105147 scheduler ,
@@ -789,6 +831,14 @@ def guidance_scale(self):
789831 def guidance_rescale (self ):
790832 return self ._guidance_rescale
791833
834+ @property
835+ def adaptive_projected_guidance (self ):
836+ return self ._adaptive_projected_guidance
837+
838+ @property
839+ def adaptive_projected_guidance_momentum (self ):
840+ return self ._adaptive_projected_guidance_momentum
841+
792842 @property
793843 def clip_skip (self ):
794844 return self ._clip_skip
@@ -845,6 +895,8 @@ def __call__(
845895 return_dict : bool = True ,
846896 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
847897 guidance_rescale : float = 0.0 ,
898+ adaptive_projected_guidance : Optional [bool ] = None ,
899+ adaptive_projected_guidance_momentum : Optional [float ] = 0.0 ,
848900 original_size : Optional [Tuple [int , int ]] = None ,
849901 crops_coords_top_left : Tuple [int , int ] = (0 , 0 ),
850902 target_size : Optional [Tuple [int , int ]] = None ,
@@ -956,6 +1008,11 @@ def __call__(
9561008 Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
9571009 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
9581010 Guidance rescale factor should fix overexposure when using zero terminal SNR.
1011+ adaptive_projected_guidance (`bool`, *optional*):
1012+ Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
1013+ in Diffusion Models](https://arxiv.org/pdf/2410.02416)
1014+ adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `0.0`):
1015+ Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
9591016 original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
9601017 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
9611018 `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1049,6 +1106,8 @@ def __call__(
10491106
10501107 self ._guidance_scale = guidance_scale
10511108 self ._guidance_rescale = guidance_rescale
1109+ self ._adaptive_projected_guidance = adaptive_projected_guidance
1110+ self ._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
10521111 self ._clip_skip = clip_skip
10531112 self ._cross_attention_kwargs = cross_attention_kwargs
10541113 self ._denoising_end = denoising_end
@@ -1181,6 +1240,11 @@ def __call__(
11811240 guidance_scale_tensor , embedding_dim = self .unet .config .time_cond_proj_dim
11821241 ).to (device = device , dtype = latents .dtype )
11831242
1243+ if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None :
1244+ momentum_buffer = MomentumBuffer (adaptive_projected_guidance_momentum )
1245+ else :
1246+ momentum_buffer = None
1247+
11841248 self ._num_timesteps = len (timesteps )
11851249 with self .progress_bar (total = num_inference_steps ) as progress_bar :
11861250 for i , t in enumerate (timesteps ):
@@ -1209,7 +1273,12 @@ def __call__(
12091273 # perform guidance
12101274 if self .do_classifier_free_guidance :
12111275 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
1212- noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
1276+ if adaptive_projected_guidance :
1277+ noise_pred = normalized_guidance (
1278+ noise_pred_text , noise_pred_uncond , self .guidance_scale , momentum_buffer
1279+ )
1280+ else :
1281+ noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
12131282
12141283 if self .do_classifier_free_guidance and self .guidance_rescale > 0.0 :
12151284 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
0 commit comments