@@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
9797
9898class SDXLCFGCutoffCallback (PipelineCallback ):
9999 """
100- Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
101- `cutoff_step_index`), this callback will disable the CFG.
100+ Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by
101+ `cutoff_step_ratio` or ` cutoff_step_index`), this callback will disable the CFG.
102102
103103 Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
104104 """
105105
106- tensor_inputs = ["prompt_embeds" , "add_text_embeds" , "add_time_ids" ]
106+ tensor_inputs = [
107+ "prompt_embeds" ,
108+ "add_text_embeds" ,
109+ "add_time_ids" ,
110+ ]
107111
108112 def callback_fn (self , pipeline , step_index , timestep , callback_kwargs ) -> Dict [str , Any ]:
109113 cutoff_step_ratio = self .config .cutoff_step_ratio
@@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
129133 callback_kwargs [self .tensor_inputs [0 ]] = prompt_embeds
130134 callback_kwargs [self .tensor_inputs [1 ]] = add_text_embeds
131135 callback_kwargs [self .tensor_inputs [2 ]] = add_time_ids
136+
137+ return callback_kwargs
138+
139+
140+ class SDXLControlnetCFGCutoffCallback (PipelineCallback ):
141+ """
142+ Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by
143+ `cutoff_step_ratio` or `cutoff_step_index`), this callback will disable the CFG.
144+
145+ Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
146+ """
147+
148+ tensor_inputs = [
149+ "prompt_embeds" ,
150+ "add_text_embeds" ,
151+ "add_time_ids" ,
152+ "image" ,
153+ ]
154+
155+ def callback_fn (self , pipeline , step_index , timestep , callback_kwargs ) -> Dict [str , Any ]:
156+ cutoff_step_ratio = self .config .cutoff_step_ratio
157+ cutoff_step_index = self .config .cutoff_step_index
158+
159+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
160+ cutoff_step = (
161+ cutoff_step_index if cutoff_step_index is not None else int (pipeline .num_timesteps * cutoff_step_ratio )
162+ )
163+
164+ if step_index == cutoff_step :
165+ prompt_embeds = callback_kwargs [self .tensor_inputs [0 ]]
166+ prompt_embeds = prompt_embeds [- 1 :] # "-1" denotes the embeddings for conditional text tokens.
167+
168+ add_text_embeds = callback_kwargs [self .tensor_inputs [1 ]]
169+ add_text_embeds = add_text_embeds [- 1 :] # "-1" denotes the embeddings for conditional pooled text tokens
170+
171+ add_time_ids = callback_kwargs [self .tensor_inputs [2 ]]
172+ add_time_ids = add_time_ids [- 1 :] # "-1" denotes the embeddings for conditional added time vector
173+
174+ # For Controlnet
175+ image = callback_kwargs [self .tensor_inputs [3 ]]
176+ image = image [- 1 :]
177+
178+ pipeline ._guidance_scale = 0.0
179+
180+ callback_kwargs [self .tensor_inputs [0 ]] = prompt_embeds
181+ callback_kwargs [self .tensor_inputs [1 ]] = add_text_embeds
182+ callback_kwargs [self .tensor_inputs [2 ]] = add_time_ids
183+ callback_kwargs [self .tensor_inputs [3 ]] = image
184+
132185 return callback_kwargs
133186
134187
0 commit comments