
Commit 31931c4

Adaptive Projected Guidance
1 parent 31058cd commit 31931c4

File tree: 2 files changed (+139, −2 lines)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 69 additions & 1 deletion
@@ -78,6 +78,47 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg
 
 
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+def normalized_guidance(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer: MomentumBuffer = None,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+):
+    """
+    Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
+    in Diffusion Models](https://arxiv.org/pdf/2410.02416)
+    """
+    diff = pred_cond - pred_uncond
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
+    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
+    return pred_guided
+
+
 def retrieve_timesteps(
     scheduler,
     num_inference_steps: Optional[int] = None,
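A quick note on what `normalized_guidance` does: it decomposes the guidance difference `pred_cond - pred_uncond` into components parallel and orthogonal to the (normalized) conditional prediction, then scales the parallel part by `eta`. One consequence worth checking: with `eta=1.0`, no momentum buffer, and `norm_threshold=0.0`, the two components sum back to the raw difference, so the function reduces to plain classifier-free guidance. A minimal sanity check (not part of this commit; it assumes `normalized_guidance` as defined above is in scope):

```python
import torch

# Illustrative latent-shaped tensors; shapes are arbitrary.
pred_cond = torch.randn(2, 4, 64, 64)
pred_uncond = torch.randn(2, 4, 64, 64)
guidance_scale = 7.5

# Standard CFG, rewritten around pred_cond:
# uncond + s * (cond - uncond) == cond + (s - 1) * (cond - uncond)
cfg = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

# With the defaults (eta=1.0, no momentum, no norm threshold) the projected
# components recombine into the raw difference, so APG matches CFG here.
apg = normalized_guidance(pred_cond, pred_uncond, guidance_scale)

torch.testing.assert_close(apg, cfg)
```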
@@ -730,6 +771,14 @@ def guidance_scale(self):
     def guidance_rescale(self):
         return self._guidance_rescale
 
+    @property
+    def adaptive_projected_guidance(self):
+        return self._adaptive_projected_guidance
+
+    @property
+    def adaptive_projected_guidance_momentum(self):
+        return self._adaptive_projected_guidance_momentum
+
     @property
     def clip_skip(self):
         return self._clip_skip
@@ -777,6 +826,8 @@ def __call__(
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
+        adaptive_projected_guidance: Optional[bool] = None,
+        adaptive_projected_guidance_momentum: Optional[float] = 0.0,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -847,6 +898,11 @@ def __call__(
                 Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                 Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                 using zero terminal SNR.
+            adaptive_projected_guidance (`bool`, *optional*):
+                Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
+                in Diffusion Models](https://arxiv.org/pdf/2410.02416)
+            adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `0.0`):
+                Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
             clip_skip (`int`, *optional*):
                 Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                 the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -910,6 +966,8 @@ def __call__(
 
         self._guidance_scale = guidance_scale
         self._guidance_rescale = guidance_rescale
+        self._adaptive_projected_guidance = adaptive_projected_guidance
+        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
         self._interrupt = False
@@ -992,6 +1050,11 @@ def __call__(
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
             ).to(device=device, dtype=latents.dtype)
 
+        if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None:
+            momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
+        else:
+            momentum_buffer = None
+
         # 7. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)
@@ -1018,7 +1081,12 @@ def __call__(
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    if adaptive_projected_guidance:
+                        noise_pred = normalized_guidance(
+                            noise_pred_text, noise_pred_uncond, self.guidance_scale, momentum_buffer
+                        )
+                    else:
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
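A hedged usage sketch for the new arguments (not part of the diff; the checkpoint id, prompt, and momentum value are assumptions for illustration):

```python
import torch
from diffusers import StableDiffusionPipeline

# Assumed checkpoint, for illustration only.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    guidance_scale=15.0,                        # APG targets high guidance scales
    adaptive_projected_guidance=True,           # routes guidance through normalized_guidance
    adaptive_projected_guidance_momentum=-0.5,  # assumed value; 0.0 behaves like no momentum
).images[0]
image.save("apg.png")
```

Note that with `adaptive_projected_guidance_momentum=0.0`, `MomentumBuffer.update` returns the current difference unchanged, so `0.0` and `None` produce the same guidance; `None` merely skips creating the buffer.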

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 70 additions & 1 deletion
@@ -100,6 +100,48 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.MomentumBuffer
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.normalized_guidance
+def normalized_guidance(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer: MomentumBuffer = None,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+):
+    """
+    Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
+    in Diffusion Models](https://arxiv.org/pdf/2410.02416)
+    """
+    diff = pred_cond - pred_uncond
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
+    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
+    return pred_guided
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -789,6 +831,14 @@ def guidance_scale(self):
     def guidance_rescale(self):
         return self._guidance_rescale
 
+    @property
+    def adaptive_projected_guidance(self):
+        return self._adaptive_projected_guidance
+
+    @property
+    def adaptive_projected_guidance_momentum(self):
+        return self._adaptive_projected_guidance_momentum
+
     @property
     def clip_skip(self):
         return self._clip_skip
@@ -845,6 +895,8 @@ def __call__(
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guidance_rescale: float = 0.0,
+        adaptive_projected_guidance: Optional[bool] = None,
+        adaptive_projected_guidance_momentum: Optional[float] = 0.0,
         original_size: Optional[Tuple[int, int]] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Optional[Tuple[int, int]] = None,
@@ -956,6 +1008,11 @@ def __call__(
                 Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
                 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
                 Guidance rescale factor should fix overexposure when using zero terminal SNR.
+            adaptive_projected_guidance (`bool`, *optional*):
+                Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
+                in Diffusion Models](https://arxiv.org/pdf/2410.02416)
+            adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `0.0`):
+                Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                 `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1049,6 +1106,8 @@ def __call__(
 
         self._guidance_scale = guidance_scale
         self._guidance_rescale = guidance_rescale
+        self._adaptive_projected_guidance = adaptive_projected_guidance
+        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
         self._denoising_end = denoising_end
@@ -1181,6 +1240,11 @@ def __call__(
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
             ).to(device=device, dtype=latents.dtype)
 
+        if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None:
+            momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
+        else:
+            momentum_buffer = None
+
         self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -1209,7 +1273,12 @@ def __call__(
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    if adaptive_projected_guidance:
+                        noise_pred = normalized_guidance(
+                            noise_pred_text, noise_pred_uncond, self.guidance_scale, momentum_buffer
+                        )
+                    else:
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
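The SDXL pipeline gains the same arguments; a parallel sketch (checkpoint id and momentum value again assumptions for illustration):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    guidance_scale=12.0,
    adaptive_projected_guidance=True,
    adaptive_projected_guidance_momentum=-0.5,  # assumed value
).images[0]
```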
