diff --git a/examples/community/README_community_scripts.md b/examples/community/README_community_scripts.md
index eae50247c9e5..3c9ad0d89bb4 100644
--- a/examples/community/README_community_scripts.md
+++ b/examples/community/README_community_scripts.md
@@ -241,7 +241,45 @@ from diffusers import StableDiffusionPipeline
 from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
 from diffusers.configuration_utils import register_to_config
 import torch
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Tuple, Union
+
+
+class SDPromptSchedulingCallback(PipelineCallback):
+    @register_to_config
+    def __init__(
+        self,
+        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        cutoff_step_ratio=None,
+        cutoff_step_index=None,
+    ):
+        super().__init__(
+            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
+        )
+
+    tensor_inputs = ["prompt_embeds"]
+
+    def callback_fn(
+        self, pipeline, step_index, timestep, callback_kwargs
+    ) -> Dict[str, Any]:
+        cutoff_step_ratio = self.config.cutoff_step_ratio
+        cutoff_step_index = self.config.cutoff_step_index
+        if isinstance(self.config.encoded_prompt, tuple):
+            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
+        else:
+            prompt_embeds = self.config.encoded_prompt
+
+        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
+        cutoff_step = (
+            cutoff_step_index
+            if cutoff_step_index is not None
+            else int(pipeline.num_timesteps * cutoff_step_ratio)
+        )
+
+        if step_index == cutoff_step:
+            if pipeline.do_classifier_free_guidance:
+                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
+        return callback_kwargs
 
 
 pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
@@ -253,28 +291,73 @@ pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
 pipeline.safety_checker = None
 pipeline.requires_safety_checker = False
 
+callback = MultiPipelineCallbacks(
+    [
+        SDPromptSchedulingCallback(
+            encoded_prompt=pipeline.encode_prompt(
+                prompt=f"prompt {index}",
+                negative_prompt=f"negative prompt {index}",
+                device=pipeline._execution_device,
+                num_images_per_prompt=1,
+                # pipeline.do_classifier_free_guidance can't be accessed until after the pipeline is run
+                do_classifier_free_guidance=True,
+            ),
+            cutoff_step_index=index,
+        ) for index in range(1, 20)
+    ]
+)
+
+image = pipeline(
+    prompt="prompt",
+    negative_prompt="negative prompt",
+    callback_on_step_end=callback,
+    callback_on_step_end_tensor_inputs=["prompt_embeds"],
+).images[0]
+torch.cuda.empty_cache()
+image.save('image.png')
+```
 
-class SDPromptScheduleCallback(PipelineCallback):
+```python
+from diffusers import StableDiffusionXLPipeline
+from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
+from diffusers.configuration_utils import register_to_config
+import torch
+from typing import Any, Dict, Tuple, Union
+
+
+class SDXLPromptSchedulingCallback(PipelineCallback):
     @register_to_config
     def __init__(
         self,
-        prompt: str,
-        negative_prompt: Optional[str] = None,
-        num_images_per_prompt: int = 1,
-        cutoff_step_ratio=1.0,
+        encoded_prompt: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        add_text_embeds: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        add_time_ids: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        cutoff_step_ratio=None,
         cutoff_step_index=None,
     ):
         super().__init__(
             cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
         )
 
-    tensor_inputs = ["prompt_embeds"]
tensor_inputs = ["prompt_embeds"] + tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"] def callback_fn( self, pipeline, step_index, timestep, callback_kwargs ) -> Dict[str, Any]: cutoff_step_ratio = self.config.cutoff_step_ratio cutoff_step_index = self.config.cutoff_step_index + if isinstance(self.config.encoded_prompt, tuple): + prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt + else: + prompt_embeds = self.config.encoded_prompt + if isinstance(self.config.add_text_embeds, tuple): + add_text_embeds, negative_add_text_embeds = self.config.add_text_embeds + else: + add_text_embeds = self.config.add_text_embeds + if isinstance(self.config.add_time_ids, tuple): + add_time_ids, negative_add_time_ids = self.config.add_time_ids + else: + add_time_ids = self.config.add_time_ids # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio cutoff_step = ( @@ -284,34 +367,73 @@ class SDPromptScheduleCallback(PipelineCallback): ) if step_index == cutoff_step: - prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt( - prompt=self.config.prompt, - negative_prompt=self.config.negative_prompt, - device=pipeline._execution_device, - num_images_per_prompt=self.config.num_images_per_prompt, - do_classifier_free_guidance=pipeline.do_classifier_free_guidance, - ) if pipeline.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds]) + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids]) callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + callback_kwargs[self.tensor_inputs[1]] = add_text_embeds + callback_kwargs[self.tensor_inputs[2]] = add_time_ids return callback_kwargs -callback = MultiPipelineCallbacks( - [ - SDPromptScheduleCallback( - prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski", - negative_prompt="Deformed, ugly, bad anatomy", - cutoff_step_ratio=0.25, + +pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained( + "stabilityai/stable-diffusion-xl-base-1.0", + torch_dtype=torch.float16, + variant="fp16", + use_safetensors=True, +).to("cuda") + +callbacks = [] +for index in range(1, 20): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = pipeline.encode_prompt( + prompt=f"prompt {index}", + negative_prompt=f"prompt {index}", + device=pipeline._execution_device, + num_images_per_prompt=1, + # pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran + do_classifier_free_guidance=True, + ) + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + add_time_ids = pipeline._get_add_time_ids( + (1024, 1024), + (0, 0), + (1024, 1024), + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + negative_add_time_ids = pipeline._get_add_time_ids( + (1024, 1024), + (0, 0), + (1024, 1024), + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, + ) + callbacks.append( + SDXLPromptSchedulingCallback( + encoded_prompt=(prompt_embeds, negative_prompt_embeds), + add_text_embeds=(pooled_prompt_embeds, negative_pooled_prompt_embeds), + add_time_ids=(add_time_ids, negative_add_time_ids), + cutoff_step_index=index, ) - ] -) + ) + + +callback = MultiPipelineCallbacks(callbacks) image = pipeline( - prompt="Official 
-    prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
-    negative_prompt="Deformed, ugly, bad anatomy",
+    prompt="prompt",
+    negative_prompt="negative prompt",
     callback_on_step_end=callback,
-    callback_on_step_end_tensor_inputs=["prompt_embeds"],
+    callback_on_step_end_tensor_inputs=[
+        "prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+    ],
 ).images[0]
-torch.cuda.empty_cache()
-image.save('image.png')
 ```
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index a4757ac2f336..d83fa6201117 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -237,11 +237,8 @@ class StableDiffusionXLPipeline(
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
-        "negative_prompt_embeds",
         "add_text_embeds",
         "add_time_ids",
-        "negative_pooled_prompt_embeds",
-        "negative_add_time_ids",
     ]
 
     def __init__(
@@ -1243,13 +1240,8 @@ def __call__(
 
                 latents = callback_outputs.pop("latents", latents)
                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                 add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                negative_pooled_prompt_embeds = callback_outputs.pop(
-                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                )
                 add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index 50688ddb1cb8..126f25a41adc 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -257,11 +257,8 @@ class StableDiffusionXLImg2ImgPipeline(
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
-        "negative_prompt_embeds",
         "add_text_embeds",
         "add_time_ids",
-        "negative_pooled_prompt_embeds",
-        "add_neg_time_ids",
     ]
 
     def __init__(
@@ -1438,13 +1435,8 @@ def denoising_value_valid(dnv):
 
                 latents = callback_outputs.pop("latents", latents)
                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                 add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                negative_pooled_prompt_embeds = callback_outputs.pop(
-                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                )
                 add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index c7c706350e8e..a378ae65eb30 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -285,11 +285,8 @@ class StableDiffusionXLInpaintPipeline(
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
-        "negative_prompt_embeds",
         "add_text_embeds",
         "add_time_ids",
-        "negative_pooled_prompt_embeds",
-        "add_neg_time_ids",
         "mask",
         "masked_image_latents",
     ]
@@ -1671,13 +1668,8 @@ def denoising_value_valid(dnv):
 
                 latents = callback_outputs.pop("latents", latents)
                 prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                 add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                negative_pooled_prompt_embeds = callback_outputs.pop(
-                    "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                )
                 add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
                 mask = callback_outputs.pop("mask", mask)
                 masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
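Not part of the patch above: the README examples schedule a new prompt at every step, which can obscure the underlying mechanism. The sketch below is a minimal, self-contained variant that swaps the prompt embeddings once, halfway through denoising, using the same `PipelineCallback` / `callback_on_step_end` machinery the patch relies on. The checkpoint id, the prompts, and the 0.5 cutoff ratio are illustrative assumptions, and `PromptSwitchCallback` is a hypothetical name, not something added by this PR.

```python
# Minimal sketch (assumptions: SD 1.5 checkpoint below, a CUDA device,
# and a diffusers version that ships PipelineCallback / callback_on_step_end).
import torch
from diffusers import StableDiffusionPipeline
from diffusers.callbacks import PipelineCallback
from diffusers.configuration_utils import register_to_config


class PromptSwitchCallback(PipelineCallback):
    # Only prompt_embeds is replaced; it must also be listed in
    # callback_on_step_end_tensor_inputs when calling the pipeline.
    tensor_inputs = ["prompt_embeds"]

    @register_to_config
    def __init__(self, encoded_prompt, cutoff_step_ratio=0.5, cutoff_step_index=None):
        super().__init__(
            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
        )

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs):
        # Resolve the cutoff either from an explicit index or from a ratio of the run.
        cutoff_step = (
            self.config.cutoff_step_index
            if self.config.cutoff_step_index is not None
            else int(pipeline.num_timesteps * self.config.cutoff_step_ratio)
        )
        if step_index == cutoff_step:
            prompt_embeds, negative_prompt_embeds = self.config.encoded_prompt
            if pipeline.do_classifier_free_guidance:
                # Match the pipeline's CFG layout: negative first, then positive.
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
        return callback_kwargs


pipeline = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",  # assumed checkpoint id
    torch_dtype=torch.float16,
).to("cuda")

# Pre-encode the prompt that takes over halfway through denoising.
switched = pipeline.encode_prompt(
    prompt="a watercolor painting of a lighthouse",
    negative_prompt="blurry, low quality",
    device=pipeline._execution_device,
    num_images_per_prompt=1,
    do_classifier_free_guidance=True,
)

image = pipeline(
    prompt="a photo of a lighthouse",
    negative_prompt="blurry, low quality",
    callback_on_step_end=PromptSwitchCallback(encoded_prompt=switched, cutoff_step_ratio=0.5),
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
image.save("switched.png")
```

Because the denoising loop pops `prompt_embeds` back out of `callback_kwargs` on every step, the swapped embeddings persist for the rest of the run, so a single callback is enough for a one-time switch; the README's `MultiPipelineCallbacks` examples simply chain one such switch per step.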