Skip to content

Commit ab1b7b2

Browse files
authored
[Official callbacks] SDXL Controlnet CFG Cutoff (#9311)
* initial proposal * style
1 parent 9366c8f commit ab1b7b2

File tree

2 files changed

+58
-3
lines changed

2 files changed

+58
-3
lines changed

src/diffusers/callbacks.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
9797

9898
class SDXLCFGCutoffCallback(PipelineCallback):
9999
"""
100-
Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
101-
`cutoff_step_index`), this callback will disable the CFG.
100+
Callback function for the base Stable Diffusion XL Pipelines. After a certain number of steps (set by
101+
`cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
102102
103103
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
104104
"""
105105

106-
tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
106+
tensor_inputs = [
107+
"prompt_embeds",
108+
"add_text_embeds",
109+
"add_time_ids",
110+
]
107111

108112
def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
109113
cutoff_step_ratio = self.config.cutoff_step_ratio
@@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
129133
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
130134
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
131135
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
136+
137+
return callback_kwargs
138+
139+
140+
class SDXLControlnetCFGCutoffCallback(PipelineCallback):
    """
    Callback for the Controlnet Stable Diffusion XL Pipelines that switches CFG off part-way through a run.

    The cutoff step is `cutoff_step_index` from the callback config when it is set; otherwise it is computed as
    `int(pipeline.num_timesteps * cutoff_step_ratio)`. On the step whose index equals the cutoff step, every tensor
    listed in `tensor_inputs` is sliced with `[-1:]` (keeping only its last entry along the first dimension) and
    written back into `callback_kwargs`. On every other step `callback_kwargs` is returned untouched.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 at the cutoff step.
    """

    # Names of the entries in `callback_kwargs` that this callback reads and overwrites.
    tensor_inputs = [
        "prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "image",
    ]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Trim the CFG-related tensors and zero the guidance scale when `step_index` hits the cutoff step.

        Returns the (possibly updated) `callback_kwargs` dictionary.
        """
        ratio = self.config.cutoff_step_ratio
        explicit_index = self.config.cutoff_step_index

        # An explicit step index takes precedence over the ratio.
        if explicit_index is not None:
            cutoff_step = explicit_index
        else:
            cutoff_step = int(pipeline.num_timesteps * ratio)

        if step_index != cutoff_step:
            return callback_kwargs

        # Read and slice everything first, so nothing is mutated if a key is missing.
        # Per the original author's notes, the last entry holds the conditional part.
        embeds_key = self.tensor_inputs[0]
        conditional_embeds = callback_kwargs[embeds_key][-1:]  # conditional text token embeddings

        pooled_key = self.tensor_inputs[1]
        conditional_pooled = callback_kwargs[pooled_key][-1:]  # conditional pooled text embeddings

        time_ids_key = self.tensor_inputs[2]
        conditional_time_ids = callback_kwargs[time_ids_key][-1:]  # conditional added time vector

        # For Controlnet: the conditioning image is trimmed the same way.
        image_key = self.tensor_inputs[3]
        conditional_image = callback_kwargs[image_key][-1:]

        pipeline._guidance_scale = 0.0

        callback_kwargs[embeds_key] = conditional_embeds
        callback_kwargs[pooled_key] = conditional_pooled
        callback_kwargs[time_ids_key] = conditional_time_ids
        callback_kwargs[image_key] = conditional_image

        return callback_kwargs
133186

134187

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline(
242242
"add_time_ids",
243243
"negative_pooled_prompt_embeds",
244244
"negative_add_time_ids",
245+
"image",
245246
]
246247

247248
def __init__(
@@ -1540,6 +1541,7 @@ def __call__(
15401541
)
15411542
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
15421543
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1544+
image = callback_outputs.pop("image", image)
15431545

15441546
# call the callback, if provided
15451547
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

0 commit comments

Comments
 (0)