diff --git a/src/diffusers/callbacks.py b/src/diffusers/callbacks.py
index 4b8b15368c47..2a08f091d9f3 100644
--- a/src/diffusers/callbacks.py
+++ b/src/diffusers/callbacks.py
@@ -207,3 +207,38 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
         if step_index == cutoff_step:
             pipeline.set_ip_adapter_scale(0.0)
         return callback_kwargs
+
+
+class SD3CFGCutoffCallback(PipelineCallback):
+    """
+    Callback function for Stable Diffusion 3 Pipelines. After a certain number of steps (set by `cutoff_step_ratio`
+    or `cutoff_step_index`), this callback will disable CFG.
+
+    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
+    """
+
+    tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]
+
+    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
+        if step_index == cutoff_step:
+            prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
+            prompt_embeds = prompt_embeds[-1:]  # "-1" denotes the embeddings for conditional text tokens.
+
+            pooled_prompt_embeds = callback_kwargs[self.tensor_inputs[1]]
+            pooled_prompt_embeds = pooled_prompt_embeds[
+                -1:
+            ]  # "-1" denotes the embeddings for conditional pooled text tokens.
+
+            pipeline._guidance_scale = 0.0
+
+            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+            callback_kwargs[self.tensor_inputs[1]] = pooled_prompt_embeds
+        return callback_kwargs
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index ad4c4b09172e..afee3f61e972 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -25,6 +25,7 @@
     T5TokenizerFast,
 )
 
+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
@@ -184,7 +185,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
 
     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
     _optional_components = ["image_encoder", "feature_extractor"]
-    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]
 
     def __init__(
         self,
@@ -923,6 +924,9 @@ def __call__(
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
 
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
@@ -1109,10 +1113,7 @@ def __call__(
 
                 latents = callback_outputs.pop("latents", latents)
                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                negative_pooled_prompt_embeds = callback_outputs.pop(
-                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                )
+                pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
 
             # call the callback, if provided
             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
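
For context, a minimal usage sketch of the new callback (the checkpoint id, prompt, step count, and cutoff ratio below are illustrative, not part of this diff):

import torch

from diffusers import StableDiffusion3Pipeline
from diffusers.callbacks import SD3CFGCutoffCallback

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Disable CFG after 40% of the denoising steps; pass cutoff_step_index
# instead to cut off at an absolute step index.
callback = SD3CFGCutoffCallback(cutoff_step_ratio=0.4)

# With this diff, passing a PipelineCallback makes the pipeline take
# callback_on_step_end_tensor_inputs from callback.tensor_inputs, so they
# do not need to be passed explicitly.
image = pipe(
    "a photo of an astronaut riding a horse on mars",
    num_inference_steps=28,
    guidance_scale=7.0,
    callback_on_step_end=callback,
).images[0]

On the `[-1:]` slicing: once `_guidance_scale` drops to 0.0, `do_classifier_free_guidance` is False and the latents are no longer batch-doubled, so the callback keeps only the trailing (conditional) half of `prompt_embeds` and `pooled_prompt_embeds` to keep batch shapes consistent.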