Skip to content

Commit bbf3466

Browse files
committed
Adaptive Projected Guidance
1 parent 31058cd commit bbf3466

File tree

2 files changed

+147
-2
lines changed

2 files changed

+147
-2
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,47 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
7878
return noise_cfg
7979

8080

81+
class MomentumBuffer:
82+
def __init__(self, momentum: float):
83+
self.momentum = momentum
84+
self.running_average = 0
85+
86+
def update(self, update_value: torch.Tensor):
87+
new_average = self.momentum * self.running_average
88+
self.running_average = update_value + new_average
89+
90+
91+
def normalized_guidance(
92+
pred_cond: torch.Tensor,
93+
pred_uncond: torch.Tensor,
94+
guidance_scale: float,
95+
momentum_buffer: MomentumBuffer = None,
96+
eta: float = 1.0,
97+
norm_threshold: float = 0.0,
98+
):
99+
"""
100+
Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
101+
in Diffusion Models](https://arxiv.org/pdf/2410.02416)
102+
"""
103+
diff = pred_cond - pred_uncond
104+
if momentum_buffer is not None:
105+
momentum_buffer.update(diff)
106+
diff = momentum_buffer.running_average
107+
if norm_threshold > 0:
108+
ones = torch.ones_like(diff)
109+
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
110+
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
111+
diff = diff * scale_factor
112+
v0, v1 = diff.double(), pred_cond.double()
113+
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
114+
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
115+
v0_orthogonal = v0 - v0_parallel
116+
diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
117+
normalized_update = diff_orthogonal + eta * diff_parallel
118+
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
119+
return pred_guided
120+
121+
81122
def retrieve_timesteps(
82123
scheduler,
83124
num_inference_steps: Optional[int] = None,
@@ -730,6 +771,14 @@ def guidance_scale(self):
730771
def guidance_rescale(self):
731772
return self._guidance_rescale
732773

774+
@property
775+
def adaptive_projected_guidance(self):
776+
return self._adaptive_projected_guidance
777+
778+
@property
779+
def adaptive_projected_guidance_momentum(self):
780+
return self._adaptive_projected_guidance_momentum
781+
733782
@property
734783
def clip_skip(self):
735784
return self._clip_skip
@@ -777,6 +826,8 @@ def __call__(
777826
return_dict: bool = True,
778827
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
779828
guidance_rescale: float = 0.0,
829+
adaptive_projected_guidance: Optional[bool] = None,
830+
adaptive_projected_guidance_momentum: Optional[float] = 0.0,
780831
clip_skip: Optional[int] = None,
781832
callback_on_step_end: Optional[
782833
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -847,6 +898,11 @@ def __call__(
847898
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
848899
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
849900
using zero terminal SNR.
901+
adaptive_projected_guidance (`bool`, *optional*):
902+
Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
903+
in Diffusion Models](https://arxiv.org/pdf/2410.02416)
904+
adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `0.0`):
905+
Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
850906
clip_skip (`int`, *optional*):
851907
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
852908
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -910,6 +966,8 @@ def __call__(
910966

911967
self._guidance_scale = guidance_scale
912968
self._guidance_rescale = guidance_rescale
969+
self._adaptive_projected_guidance = adaptive_projected_guidance
970+
self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
913971
self._clip_skip = clip_skip
914972
self._cross_attention_kwargs = cross_attention_kwargs
915973
self._interrupt = False
@@ -992,6 +1050,11 @@ def __call__(
9921050
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
9931051
).to(device=device, dtype=latents.dtype)
9941052

1053+
if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None:
1054+
momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
1055+
else:
1056+
momentum_buffer = None
1057+
9951058
# 7. Denoising loop
9961059
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
9971060
self._num_timesteps = len(timesteps)
@@ -1018,7 +1081,12 @@ def __call__(
10181081
# perform guidance
10191082
if self.do_classifier_free_guidance:
10201083
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1021-
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1084+
if adaptive_projected_guidance:
1085+
noise_pred = normalized_guidance(
1086+
noise_pred_text, noise_pred_uncond, self.guidance_scale, momentum_buffer
1087+
)
1088+
else:
1089+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
10221090

10231091
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
10241092
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,48 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
100100
return noise_cfg
101101

102102

103+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.MomentumBuffer
104+
class MomentumBuffer:
105+
def __init__(self, momentum: float):
106+
self.momentum = momentum
107+
self.running_average = 0
108+
109+
def update(self, update_value: torch.Tensor):
110+
new_average = self.momentum * self.running_average
111+
self.running_average = update_value + new_average
112+
113+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.normalized_guidance
114+
def normalized_guidance(
115+
pred_cond: torch.Tensor,
116+
pred_uncond: torch.Tensor,
117+
guidance_scale: float,
118+
momentum_buffer: MomentumBuffer = None,
119+
eta: float = 1.0,
120+
norm_threshold: float = 0.0,
121+
):
122+
"""
123+
Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
124+
in Diffusion Models](https://arxiv.org/pdf/2410.02416)
125+
"""
126+
diff = pred_cond - pred_uncond
127+
if momentum_buffer is not None:
128+
momentum_buffer.update(diff)
129+
diff = momentum_buffer.running_average
130+
if norm_threshold > 0:
131+
ones = torch.ones_like(diff)
132+
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
133+
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
134+
diff = diff * scale_factor
135+
v0, v1 = diff.double(), pred_cond.double()
136+
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
137+
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
138+
v0_orthogonal = v0 - v0_parallel
139+
diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
140+
normalized_update = diff_orthogonal + eta * diff_parallel
141+
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
142+
return pred_guided
143+
144+
103145
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104146
def retrieve_timesteps(
105147
scheduler,
@@ -789,6 +831,18 @@ def guidance_scale(self):
789831
def guidance_rescale(self):
790832
return self._guidance_rescale
791833

834+
@property
835+
def adaptive_projected_guidance(self):
836+
return self._adaptive_projected_guidance
837+
838+
@property
839+
def adaptive_projected_guidance_momentum(self):
840+
return self._adaptive_projected_guidance_momentum
841+
842+
@property
843+
def adaptive_projected_guidance_rescale_factor(self):
844+
return self._adaptive_projected_guidance_rescale_factor
845+
792846
@property
793847
def clip_skip(self):
794848
return self._clip_skip
@@ -845,6 +899,9 @@ def __call__(
845899
return_dict: bool = True,
846900
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
847901
guidance_rescale: float = 0.0,
902+
adaptive_projected_guidance: Optional[bool] = None,
903+
adaptive_projected_guidance_momentum: Optional[float] = 0.0,
904+
adaptive_projected_guidance_rescale_factor: Optional[float] = 15.0,
848905
original_size: Optional[Tuple[int, int]] = None,
849906
crops_coords_top_left: Tuple[int, int] = (0, 0),
850907
target_size: Optional[Tuple[int, int]] = None,
@@ -956,6 +1013,13 @@ def __call__(
9561013
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
9571014
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
9581015
Guidance rescale factor should fix overexposure when using zero terminal SNR.
1016+
adaptive_projected_guidance (`bool`, *optional*):
1017+
Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
1018+
in Diffusion Models](https://arxiv.org/pdf/2410.02416)
1019+
adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.5`):
1020+
Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
1021+
adaptive_projected_guidance_rescale_factor (`float`, *optional*, defaults to `15.0`):
1022+
Momentum value to use with adaptive projected guidance. Use `None` to disable momentum.
9591023
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
9601024
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
9611025
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1049,6 +1113,9 @@ def __call__(
10491113

10501114
self._guidance_scale = guidance_scale
10511115
self._guidance_rescale = guidance_rescale
1116+
self._adaptive_projected_guidance = adaptive_projected_guidance
1117+
self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
1118+
self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
10521119
self._clip_skip = clip_skip
10531120
self._cross_attention_kwargs = cross_attention_kwargs
10541121
self._denoising_end = denoising_end
@@ -1181,6 +1248,11 @@ def __call__(
11811248
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
11821249
).to(device=device, dtype=latents.dtype)
11831250

1251+
if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None:
1252+
momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
1253+
else:
1254+
momentum_buffer = None
1255+
11841256
self._num_timesteps = len(timesteps)
11851257
with self.progress_bar(total=num_inference_steps) as progress_bar:
11861258
for i, t in enumerate(timesteps):
@@ -1209,7 +1281,12 @@ def __call__(
12091281
# perform guidance
12101282
if self.do_classifier_free_guidance:
12111283
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1212-
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1284+
if adaptive_projected_guidance:
1285+
noise_pred = normalized_guidance(
1286+
noise_pred_text, noise_pred_uncond, self.guidance_scale, momentum_buffer, norm_threshold=adaptive_projected_guidance_rescale_factor
1287+
)
1288+
else:
1289+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
12131290

12141291
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
12151292
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf

0 commit comments

Comments
 (0)