Skip to content

Commit 44db783

Browse files
committed
Adaptive Projected Guidance
1 parent 31058cd commit 44db783

File tree

1 file changed

+69
-1
lines changed

1 file changed

+69
-1
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,47 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
7878
return noise_cfg
7979

8080

81+
class MomentumBuffer:
82+
def __init__(self, momentum: float):
83+
self.momentum = momentum
84+
self.running_average = 0
85+
86+
def update(self, update_value: torch.Tensor):
87+
new_average = self.momentum * self.running_average
88+
self.running_average = update_value + new_average
89+
90+
91+
def normalized_guidance(
92+
pred_cond: torch.Tensor,
93+
pred_uncond: torch.Tensor,
94+
guidance_scale: float,
95+
momentum_buffer: MomentumBuffer = None,
96+
eta: float = 1.0,
97+
norm_threshold: float = 0.0,
98+
):
99+
"""
100+
Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
101+
in Diffusion Models](https://arxiv.org/pdf/2410.02416)
102+
"""
103+
diff = pred_cond - pred_uncond
104+
if momentum_buffer is not None:
105+
momentum_buffer.update(diff)
106+
diff = momentum_buffer.running_average
107+
if norm_threshold > 0:
108+
ones = torch.ones_like(diff)
109+
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
110+
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
111+
diff = diff * scale_factor
112+
v0, v1 = diff.double(), pred_cond.double()
113+
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
114+
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
115+
v0_orthogonal = v0 - v0_parallel
116+
diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
117+
normalized_update = diff_orthogonal + eta * diff_parallel
118+
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
119+
return pred_guided
120+
121+
81122
def retrieve_timesteps(
82123
scheduler,
83124
num_inference_steps: Optional[int] = None,
@@ -730,6 +771,14 @@ def guidance_scale(self):
730771
def guidance_rescale(self):
731772
return self._guidance_rescale
732773

774+
@property
775+
def adaptive_projected_guidance(self):
776+
return self._adaptive_projected_guidance
777+
778+
@property
779+
def adaptive_projected_guidance_momentum(self):
780+
return self._adaptive_projected_guidance_momentum
781+
733782
@property
734783
def clip_skip(self):
735784
return self._clip_skip
@@ -777,6 +826,8 @@ def __call__(
777826
return_dict: bool = True,
778827
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
779828
guidance_rescale: float = 0.0,
829+
adaptive_projected_guidance: Optional[bool] = False,
830+
adaptive_projected_guidance_momentum: Optional[float] = -0.75,
780831
clip_skip: Optional[int] = None,
781832
callback_on_step_end: Optional[
782833
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -847,6 +898,11 @@ def __call__(
847898
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
848899
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
849900
using zero terminal SNR.
901+
adaptive_projected_guidance (`bool`, *optional*, defaults to `False`):
902+
Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
903+
in Diffusion Models](https://arxiv.org/pdf/2410.02416)
904+
adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.75`):
905+
Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
850906
clip_skip (`int`, *optional*):
851907
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
852908
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -910,6 +966,8 @@ def __call__(
910966

911967
self._guidance_scale = guidance_scale
912968
self._guidance_rescale = guidance_rescale
969+
self._adaptive_projected_guidance = adaptive_projected_guidance
970+
self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
913971
self._clip_skip = clip_skip
914972
self._cross_attention_kwargs = cross_attention_kwargs
915973
self._interrupt = False
@@ -992,6 +1050,11 @@ def __call__(
9921050
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
9931051
).to(device=device, dtype=latents.dtype)
9941052

1053+
if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None:
1054+
momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
1055+
else:
1056+
momentum_buffer = None
1057+
9951058
# 7. Denoising loop
9961059
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
9971060
self._num_timesteps = len(timesteps)
@@ -1018,7 +1081,12 @@ def __call__(
10181081
# perform guidance
10191082
if self.do_classifier_free_guidance:
10201083
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1021-
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1084+
if adaptive_projected_guidance:
1085+
noise_pred = normalized_guidance(
1086+
noise_pred_text, noise_pred_uncond, self.guidance_scale, momentum_buffer, eta
1087+
)
1088+
else:
1089+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
10221090

10231091
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
10241092
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf

0 commit comments

Comments
 (0)