@@ -78,6 +78,47 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
7878 return noise_cfg
7979
8080
class MomentumBuffer:
    """Accumulator that blends each new value with a scaled copy of its history.

    After every `update(v)` the stored state becomes
    `v + momentum * previous_state`; the state starts at `0`, so the first
    update simply stores `v`.
    """

    def __init__(self, momentum: float):
        # Scalar start value: the first update yields `update_value` itself
        # (plus a zero term), whatever shape the incoming tensor has.
        self.running_average = 0
        self.momentum = momentum

    def update(self, update_value: torch.Tensor):
        """Fold `update_value` into `running_average` (rebinds the attribute, returns nothing)."""
        self.running_average = update_value + self.momentum * self.running_average
89+
90+
def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional["MomentumBuffer"] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
) -> torch.Tensor:
    """
    Combine conditional and unconditional predictions with adaptive projected guidance.

    Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
    in Diffusion Models](https://arxiv.org/pdf/2410.02416)

    The guidance direction `pred_cond - pred_uncond` is optionally smoothed through `momentum_buffer`,
    optionally rescaled so that its L2 norm does not exceed `norm_threshold`, and then split into the
    component parallel to `pred_cond` and the component orthogonal to it. The parallel component is
    weighted by `eta` before the update is applied.

    All reductions run over the last three dimensions, so inputs are presumably laid out as
    `(batch, channels, height, width)` - confirm against the caller.

    Args:
        pred_cond (`torch.Tensor`):
            Prediction obtained with conditioning.
        pred_uncond (`torch.Tensor`):
            Prediction obtained without conditioning.
        guidance_scale (`float`):
            Guidance scale; the update is multiplied by `guidance_scale - 1` and added to `pred_cond`.
        momentum_buffer (*optional*):
            Object exposing `update(tensor)` and a `running_average` attribute. When given, it is updated
            in place with the current difference and its `running_average` replaces the difference.
        eta (`float`, *optional*, defaults to 1.0):
            Weight of the component parallel to `pred_cond`. With `eta=1.0` and no momentum or threshold the
            result matches plain classifier-free guidance up to floating-point rounding.
        norm_threshold (`float`, *optional*, defaults to 0.0):
            Upper bound on the L2 norm of the guidance direction. Values `<= 0` disable the rescaling.

    Returns:
        `torch.Tensor`: The guided prediction, `pred_cond + (guidance_scale - 1) * update`.
    """
    reduce_dims = [-1, -2, -3]

    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=reduce_dims, keepdim=True)
        # Only ever shrink the direction. A zero norm gives an infinite ratio, which the clamp
        # turns into a factor of 1, so an all-zero difference passes through unchanged.
        scale_factor = (norm_threshold / diff_norm).clamp(max=1.0)
        diff = diff * scale_factor

    # The projection is done in higher precision to limit cancellation error. The MPS backend has
    # no float64 support, so fall back to float32 there instead of raising.
    compute_dtype = torch.float32 if diff.device.type == "mps" else torch.float64
    v0 = diff.to(compute_dtype)
    v1 = torch.nn.functional.normalize(pred_cond.to(compute_dtype), dim=reduce_dims)

    v0_parallel = (v0 * v1).sum(dim=reduce_dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel = v0_parallel.to(diff.dtype)
    diff_orthogonal = v0_orthogonal.to(diff.dtype)

    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
120+
121+
81122def retrieve_timesteps (
82123 scheduler ,
83124 num_inference_steps : Optional [int ] = None ,
@@ -730,6 +771,14 @@ def guidance_scale(self):
730771 def guidance_rescale (self ):
731772 return self ._guidance_rescale
732773
@property
def adaptive_projected_guidance(self):
    """Return `self._adaptive_projected_guidance`, the flag `__call__` stores from its argument of the same name."""
    return self._adaptive_projected_guidance
777+
@property
def adaptive_projected_guidance_momentum(self):
    """Return `self._adaptive_projected_guidance_momentum`, the momentum value `__call__` stores from its argument."""
    return self._adaptive_projected_guidance_momentum
781+
733782 @property
734783 def clip_skip (self ):
735784 return self ._clip_skip
@@ -777,6 +826,8 @@ def __call__(
777826 return_dict : bool = True ,
778827 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
779828 guidance_rescale : float = 0.0 ,
829+ adaptive_projected_guidance : Optional [bool ] = False ,
830+ adaptive_projected_guidance_momentum : Optional [float ] = - 0.75 ,
780831 clip_skip : Optional [int ] = None ,
781832 callback_on_step_end : Optional [
782833 Union [Callable [[int , int , Dict ], None ], PipelineCallback , MultiPipelineCallbacks ]
@@ -847,6 +898,11 @@ def __call__(
847898 Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
848899 Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
849900 using zero terminal SNR.
901+ adaptive_projected_guidance (`bool`, *optional*, defaults to `False`):
902+ Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
903+ in Diffusion Models](https://arxiv.org/pdf/2410.02416)
904+ adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.75`):
905+ Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
850906 clip_skip (`int`, *optional*):
851907 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
852908 the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -910,6 +966,8 @@ def __call__(
910966
911967 self ._guidance_scale = guidance_scale
912968 self ._guidance_rescale = guidance_rescale
969+ self ._adaptive_projected_guidance = adaptive_projected_guidance
970+ self ._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
913971 self ._clip_skip = clip_skip
914972 self ._cross_attention_kwargs = cross_attention_kwargs
915973 self ._interrupt = False
@@ -992,6 +1050,11 @@ def __call__(
9921050 guidance_scale_tensor , embedding_dim = self .unet .config .time_cond_proj_dim
9931051 ).to (device = device , dtype = latents .dtype )
9941052
1053+ if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None :
1054+ momentum_buffer = MomentumBuffer (adaptive_projected_guidance_momentum )
1055+ else :
1056+ momentum_buffer = None
1057+
9951058 # 7. Denoising loop
9961059 num_warmup_steps = len (timesteps ) - num_inference_steps * self .scheduler .order
9971060 self ._num_timesteps = len (timesteps )
@@ -1018,7 +1081,12 @@ def __call__(
10181081 # perform guidance
10191082 if self .do_classifier_free_guidance :
10201083 noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
1021- noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
1084+ if adaptive_projected_guidance :
1085+ noise_pred = normalized_guidance (
1086+ noise_pred_text , noise_pred_uncond , self .guidance_scale , momentum_buffer , eta
1087+ )
1088+ else :
1089+ noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
10221090
10231091 if self .do_classifier_free_guidance and self .guidance_rescale > 0.0 :
10241092 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
0 commit comments