@@ -100,6 +100,49 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
100100 return noise_cfg
101101
102102
103+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.MomentumBuffer
104+ class MomentumBuffer :
105+ def __init__ (self , momentum : float ):
106+ self .momentum = momentum
107+ self .running_average = 0
108+
109+ def update (self , update_value : torch .Tensor ):
110+ new_average = self .momentum * self .running_average
111+ self .running_average = update_value + new_average
112+
113+
114+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.normalized_guidance
115+ def normalized_guidance (
116+ pred_cond : torch .Tensor ,
117+ pred_uncond : torch .Tensor ,
118+ guidance_scale : float ,
119+ momentum_buffer : MomentumBuffer = None ,
120+ eta : float = 1.0 ,
121+ norm_threshold : float = 0.0 ,
122+ ):
123+ """
124+ Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
125+ in Diffusion Models](https://arxiv.org/pdf/2410.02416)
126+ """
127+ diff = pred_cond - pred_uncond
128+ if momentum_buffer is not None :
129+ momentum_buffer .update (diff )
130+ diff = momentum_buffer .running_average
131+ if norm_threshold > 0 :
132+ ones = torch .ones_like (diff )
133+ diff_norm = diff .norm (p = 2 , dim = [- 1 , - 2 , - 3 ], keepdim = True )
134+ scale_factor = torch .minimum (ones , norm_threshold / diff_norm )
135+ diff = diff * scale_factor
136+ v0 , v1 = diff .double (), pred_cond .double ()
137+ v1 = torch .nn .functional .normalize (v1 , dim = [- 1 , - 2 , - 3 ])
138+ v0_parallel = (v0 * v1 ).sum (dim = [- 1 , - 2 , - 3 ], keepdim = True ) * v1
139+ v0_orthogonal = v0 - v0_parallel
140+ diff_parallel , diff_orthogonal = v0_parallel .to (diff .dtype ), v0_orthogonal .to (diff .dtype )
141+ normalized_update = diff_orthogonal + eta * diff_parallel
142+ pred_guided = pred_cond + (guidance_scale - 1 ) * normalized_update
143+ return pred_guided
144+
145+
103146# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104147def retrieve_timesteps (
105148 scheduler ,
@@ -789,6 +832,18 @@ def guidance_scale(self):
789832 def guidance_rescale (self ):
790833 return self ._guidance_rescale
791834
835+ @property
836+ def adaptive_projected_guidance (self ):
837+ return self ._adaptive_projected_guidance
838+
839+ @property
840+ def adaptive_projected_guidance_momentum (self ):
841+ return self ._adaptive_projected_guidance_momentum
842+
843+ @property
844+ def adaptive_projected_guidance_rescale_factor (self ):
845+ return self ._adaptive_projected_guidance_rescale_factor
846+
792847 @property
793848 def clip_skip (self ):
794849 return self ._clip_skip
@@ -845,6 +900,9 @@ def __call__(
845900 return_dict : bool = True ,
846901 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
847902 guidance_rescale : float = 0.0 ,
903+ adaptive_projected_guidance : Optional [bool ] = None ,
904+ adaptive_projected_guidance_momentum : Optional [float ] = - 0.5 ,
905+ adaptive_projected_guidance_rescale_factor : Optional [float ] = 15.0 ,
848906 original_size : Optional [Tuple [int , int ]] = None ,
849907 crops_coords_top_left : Tuple [int , int ] = (0 , 0 ),
850908 target_size : Optional [Tuple [int , int ]] = None ,
@@ -956,6 +1014,13 @@ def __call__(
9561014 Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
9571015 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
9581016 Guidance rescale factor should fix overexposure when using zero terminal SNR.
1017+ adaptive_projected_guidance (`bool`, *optional*):
1018+ Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
1019+ in Diffusion Models](https://arxiv.org/pdf/2410.02416)
1020+ adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.5`):
1021+ Momentum to use with adaptive projected guidance. Use `None` to disable momentum.
1022+ adaptive_projected_guidance_rescale_factor (`float`, *optional*, defaults to `15.0`):
1023+ Rescale factor to use with adaptive projected guidance.
9591024 original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
9601025 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
9611026 `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1049,6 +1114,9 @@ def __call__(
10491114
10501115 self ._guidance_scale = guidance_scale
10511116 self ._guidance_rescale = guidance_rescale
1117+ self ._adaptive_projected_guidance = adaptive_projected_guidance
1118+ self ._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
1119+ self ._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
10521120 self ._clip_skip = clip_skip
10531121 self ._cross_attention_kwargs = cross_attention_kwargs
10541122 self ._denoising_end = denoising_end
@@ -1181,6 +1249,11 @@ def __call__(
11811249 guidance_scale_tensor , embedding_dim = self .unet .config .time_cond_proj_dim
11821250 ).to (device = device , dtype = latents .dtype )
11831251
1252+ if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None :
1253+ momentum_buffer = MomentumBuffer (adaptive_projected_guidance_momentum )
1254+ else :
1255+ momentum_buffer = None
1256+
11841257 self ._num_timesteps = len (timesteps )
11851258 with self .progress_bar (total = num_inference_steps ) as progress_bar :
11861259 for i , t in enumerate (timesteps ):
@@ -1209,7 +1282,17 @@ def __call__(
12091282 # perform guidance
12101283 if self .do_classifier_free_guidance :
12111284 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
1212- noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
1285+ if adaptive_projected_guidance :
1286+ noise_pred = normalized_guidance (
1287+ noise_pred_text ,
1288+ noise_pred_uncond ,
1289+ self .guidance_scale ,
1290+ momentum_buffer ,
1291+ eta ,
1292+ adaptive_projected_guidance_rescale_factor ,
1293+ )
1294+ else :
1295+ noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
12131296
12141297 if self .do_classifier_free_guidance and self .guidance_rescale > 0.0 :
12151298 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
0 commit comments