From b6aea1e724a4c3018eab260c7421bcc846e9dcde Mon Sep 17 00:00:00 2001 From: hlky Date: Sat, 19 Oct 2024 17:02:30 +0100 Subject: [PATCH] Add prompt scheduling callback to community scripts --- .../community/README_community_scripts.md | 85 ++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/examples/community/README_community_scripts.md b/examples/community/README_community_scripts.md index 8432b4e82c9f..2c2f549a2bd5 100644 --- a/examples/community/README_community_scripts.md +++ b/examples/community/README_community_scripts.md @@ -8,6 +8,7 @@ If a community script doesn't work as expected, please open an issue and ping th |:--------------------------------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------------------------------------------------------------------------------------------|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------:| | Using IP-Adapter with negative noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) | | [Álvaro Somoza](https://github.com/asomoza)| | 
asymmetric tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#asymmetric-tiling ) | | [alexisrolland](https://github.com/alexisrolland)| +| Prompt scheduling callback |Allows changing prompts during a generation | [Prompt Scheduling](#prompt-scheduling-callback ) | | [hlky](https://github.com/hlky)| ## Example usages @@ -229,4 +230,86 @@ seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False) torch.cuda.empty_cache() image.save('image.png') -``` \ No newline at end of file +``` + +### Prompt Scheduling callback + +Prompt scheduling callback allows changing prompts during a generation, like [prompt editing in A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing) + +```python +from diffusers import StableDiffusionPipeline +from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks +from diffusers.configuration_utils import register_to_config +import torch +from typing import Any, Dict, Optional + + +pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained( + "stable-diffusion-v1-5/stable-diffusion-v1-5", + torch_dtype=torch.float16, + variant="fp16", + use_safetensors=True, +).to("cuda") +pipeline.safety_checker = None +pipeline.requires_safety_checker = False + + +class SDPromptScheduleCallback(PipelineCallback): + @register_to_config + def __init__( + self, + prompt: str, + negative_prompt: Optional[str] = None, + num_images_per_prompt: int = 1, + cutoff_step_ratio=1.0, + cutoff_step_index=None, + ): + super().__init__( + cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index + ) + + tensor_inputs = ["prompt_embeds"] + + def callback_fn( + self, pipeline, step_index, timestep, callback_kwargs + ) -> Dict[str, Any]: + cutoff_step_ratio = self.config.cutoff_step_ratio + cutoff_step_index = self.config.cutoff_step_index + + # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio + cutoff_step = ( + cutoff_step_index +
if cutoff_step_index is not None + else int(pipeline.num_timesteps * cutoff_step_ratio) + ) + + if step_index == cutoff_step: + prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt( + prompt=self.config.prompt, + negative_prompt=self.config.negative_prompt, + device=pipeline._execution_device, + num_images_per_prompt=self.config.num_images_per_prompt, + do_classifier_free_guidance=pipeline.do_classifier_free_guidance, + ) + if pipeline.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + callback_kwargs[self.tensor_inputs[0]] = prompt_embeds + return callback_kwargs + +callback = MultiPipelineCallbacks( + [ + SDPromptScheduleCallback( + prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski", + negative_prompt="Deformed, ugly, bad anatomy", + cutoff_step_ratio=0.25, + ) + ] +) + +image = pipeline( + prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski", + negative_prompt="Deformed, ugly, bad anatomy", + callback_on_step_end=callback, + callback_on_step_end_tensor_inputs=["prompt_embeds"], +).images[0] +```