@@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
97
97
98
98
class SDXLCFGCutoffCallback (PipelineCallback ):
99
99
"""
100
- Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
101
- `cutoff_step_index`), this callback will disable the CFG.
100
+ Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by
101
+ `cutoff_step_ratio` or ` cutoff_step_index`), this callback will disable the CFG.
102
102
103
103
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
104
104
"""
105
105
106
- tensor_inputs = ["prompt_embeds" , "add_text_embeds" , "add_time_ids" ]
106
+ tensor_inputs = [
107
+ "prompt_embeds" ,
108
+ "add_text_embeds" ,
109
+ "add_time_ids" ,
110
+ ]
107
111
108
112
def callback_fn (self , pipeline , step_index , timestep , callback_kwargs ) -> Dict [str , Any ]:
109
113
cutoff_step_ratio = self .config .cutoff_step_ratio
@@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
129
133
callback_kwargs [self .tensor_inputs [0 ]] = prompt_embeds
130
134
callback_kwargs [self .tensor_inputs [1 ]] = add_text_embeds
131
135
callback_kwargs [self .tensor_inputs [2 ]] = add_time_ids
136
+
137
+ return callback_kwargs
138
+
139
+
140
+ class SDXLControlnetCFGCutoffCallback (PipelineCallback ):
141
+ """
142
+ Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by
143
+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
144
+
145
+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
146
+ """
147
+
148
+ tensor_inputs = [
149
+ "prompt_embeds" ,
150
+ "add_text_embeds" ,
151
+ "add_time_ids" ,
152
+ "image" ,
153
+ ]
154
+
155
+ def callback_fn (self , pipeline , step_index , timestep , callback_kwargs ) -> Dict [str , Any ]:
156
+ cutoff_step_ratio = self .config .cutoff_step_ratio
157
+ cutoff_step_index = self .config .cutoff_step_index
158
+
159
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
160
+ cutoff_step = (
161
+ cutoff_step_index if cutoff_step_index is not None else int (pipeline .num_timesteps * cutoff_step_ratio )
162
+ )
163
+
164
+ if step_index == cutoff_step :
165
+ prompt_embeds = callback_kwargs [self .tensor_inputs [0 ]]
166
+ prompt_embeds = prompt_embeds [- 1 :] # "-1" denotes the embeddings for conditional text tokens.
167
+
168
+ add_text_embeds = callback_kwargs [self .tensor_inputs [1 ]]
169
+ add_text_embeds = add_text_embeds [- 1 :] # "-1" denotes the embeddings for conditional pooled text tokens
170
+
171
+ add_time_ids = callback_kwargs [self .tensor_inputs [2 ]]
172
+ add_time_ids = add_time_ids [- 1 :] # "-1" denotes the embeddings for conditional added time vector
173
+
174
+ # For Controlnet
175
+ image = callback_kwargs [self .tensor_inputs [3 ]]
176
+ image = image [- 1 :]
177
+
178
+ pipeline ._guidance_scale = 0.0
179
+
180
+ callback_kwargs [self .tensor_inputs [0 ]] = prompt_embeds
181
+ callback_kwargs [self .tensor_inputs [1 ]] = add_text_embeds
182
+ callback_kwargs [self .tensor_inputs [2 ]] = add_time_ids
183
+ callback_kwargs [self .tensor_inputs [3 ]] = image
184
+
132
185
return callback_kwargs
133
186
134
187
0 commit comments