@@ -100,6 +100,48 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
100100 return noise_cfg
101101
102102
103+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.MomentumBuffer
104+ class MomentumBuffer :
105+ def __init__ (self , momentum : float ):
106+ self .momentum = momentum
107+ self .running_average = 0
108+
109+ def update (self , update_value : torch .Tensor ):
110+ new_average = self .momentum * self .running_average
111+ self .running_average = update_value + new_average
112+
113+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.normalized_guidance
114+ def normalized_guidance (
115+ pred_cond : torch .Tensor ,
116+ pred_uncond : torch .Tensor ,
117+ guidance_scale : float ,
118+ momentum_buffer : MomentumBuffer = None ,
119+ eta : float = 1.0 ,
120+ norm_threshold : float = 0.0 ,
121+ ):
122+ """
123+ Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
124+ in Diffusion Models](https://arxiv.org/pdf/2410.02416)
125+ """
126+ diff = pred_cond - pred_uncond
127+ if momentum_buffer is not None :
128+ momentum_buffer .update (diff )
129+ diff = momentum_buffer .running_average
130+ if norm_threshold > 0 :
131+ ones = torch .ones_like (diff )
132+ diff_norm = diff .norm (p = 2 , dim = [- 1 , - 2 , - 3 ], keepdim = True )
133+ scale_factor = torch .minimum (ones , norm_threshold / diff_norm )
134+ diff = diff * scale_factor
135+ v0 , v1 = diff .double (), pred_cond .double ()
136+ v1 = torch .nn .functional .normalize (v1 , dim = [- 1 , - 2 , - 3 ])
137+ v0_parallel = (v0 * v1 ).sum (dim = [- 1 , - 2 , - 3 ], keepdim = True ) * v1
138+ v0_orthogonal = v0 - v0_parallel
139+ diff_parallel , diff_orthogonal = v0_parallel .to (diff .dtype ), v0_orthogonal .to (diff .dtype )
140+ normalized_update = diff_orthogonal + eta * diff_parallel
141+ pred_guided = pred_cond + (guidance_scale - 1 ) * normalized_update
142+ return pred_guided
143+
144+
103145# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104146def retrieve_timesteps (
105147 scheduler ,
@@ -789,6 +831,18 @@ def guidance_scale(self):
789831 def guidance_rescale (self ):
790832 return self ._guidance_rescale
791833
834+ @property
835+ def adaptive_projected_guidance (self ):
836+ return self ._adaptive_projected_guidance
837+
838+ @property
839+ def adaptive_projected_guidance_momentum (self ):
840+ return self ._adaptive_projected_guidance_momentum
841+
842+ @property
843+ def adaptive_projected_guidance_rescale_factor (self ):
844+ return self ._adaptive_projected_guidance_rescale_factor
845+
792846 @property
793847 def clip_skip (self ):
794848 return self ._clip_skip
@@ -845,6 +899,9 @@ def __call__(
845899 return_dict : bool = True ,
846900 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
847901 guidance_rescale : float = 0.0 ,
902+ adaptive_projected_guidance : Optional [bool ] = None ,
903+ adaptive_projected_guidance_momentum : Optional [float ] = 0.0 ,
904+ adaptive_projected_guidance_rescale_factor : Optional [float ] = 15.0 ,
848905 original_size : Optional [Tuple [int , int ]] = None ,
849906 crops_coords_top_left : Tuple [int , int ] = (0 , 0 ),
850907 target_size : Optional [Tuple [int , int ]] = None ,
@@ -956,6 +1013,13 @@ def __call__(
9561013 Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
9571014 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
9581015 Guidance rescale factor should fix overexposure when using zero terminal SNR.
1016+ adaptive_projected_guidance (`bool`, *optional*):
1017+ Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
1018+ in Diffusion Models](https://arxiv.org/pdf/2410.02416)
1019+ adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.5`):
1020+ Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
1021+ adaptive_projected_guidance_rescale_factor (`float`, *optional*, defaults to `15.0`):
1022+ Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
9591023 original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
9601024 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
9611025 `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1049,6 +1113,9 @@ def __call__(
10491113
10501114 self ._guidance_scale = guidance_scale
10511115 self ._guidance_rescale = guidance_rescale
1116+ self ._adaptive_projected_guidance = adaptive_projected_guidance
1117+ self ._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
1118+ self ._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
10521119 self ._clip_skip = clip_skip
10531120 self ._cross_attention_kwargs = cross_attention_kwargs
10541121 self ._denoising_end = denoising_end
@@ -1181,6 +1248,11 @@ def __call__(
11811248 guidance_scale_tensor , embedding_dim = self .unet .config .time_cond_proj_dim
11821249 ).to (device = device , dtype = latents .dtype )
11831250
1251+ if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None :
1252+ momentum_buffer = MomentumBuffer (adaptive_projected_guidance_momentum )
1253+ else :
1254+ momentum_buffer = None
1255+
11841256 self ._num_timesteps = len (timesteps )
11851257 with self .progress_bar (total = num_inference_steps ) as progress_bar :
11861258 for i , t in enumerate (timesteps ):
@@ -1209,7 +1281,12 @@ def __call__(
12091281 # perform guidance
12101282 if self .do_classifier_free_guidance :
12111283 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
1212- noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
1284+ if adaptive_projected_guidance :
1285+ noise_pred = normalized_guidance (
1286+ noise_pred_text , noise_pred_uncond , self .guidance_scale , momentum_buffer , norm_threshold = adaptive_projected_guidance_rescale_factor
1287+ )
1288+ else :
1289+ noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
12131290
12141291 if self .do_classifier_free_guidance and self .guidance_rescale > 0.0 :
12151292 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
0 commit comments