
Commit b6aea1e

Add prompt scheduling callback to community scripts
1 parent 5d3e7bd commit b6aea1e

1 file changed: +84 -1 lines changed

examples/community/README_community_scripts.md

Lines changed: 84 additions & 1 deletion
@@ -8,6 +8,7 @@ If a community script doesn't work as expected, please open an issue and ping th
|:---|:---|:---|:---|---:|
| Using IP-Adapter with negative noise | Using negative noise with IP-adapter to better control the generation (see the [original post](https://github.com/huggingface/diffusers/discussions/7167) on the forum for more details) | [IP-Adapter Negative Noise](#ip-adapter-negative-noise) | | [Álvaro Somoza](https://github.com/asomoza)|
| asymmetric tiling |configure seamless image tiling independently for the X and Y axes | [Asymmetric Tiling](#asymmetric-tiling) | | [alexisrolland](https://github.com/alexisrolland)|
| Prompt scheduling callback | Allows changing prompts during a generation | [Prompt Scheduling](#prompt-scheduling-callback) | | [hlky](https://github.com/hlky)|

## Example usages
@@ -229,4 +230,86 @@ seamless_tiling(pipeline=pipeline, x_axis=False, y_axis=False)

torch.cuda.empty_cache()
image.save('image.png')
```

### Prompt Scheduling callback

The prompt scheduling callback allows changing the prompt during generation, similar to [prompt editing in A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing).

```python
from diffusers import StableDiffusionPipeline
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
from diffusers.configuration_utils import register_to_config
import torch
from typing import Any, Dict, Optional


pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")
pipeline.safety_checker = None
pipeline.requires_safety_checker = False


class SDPromptScheduleCallback(PipelineCallback):
    @register_to_config
    def __init__(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        num_images_per_prompt: int = 1,
        cutoff_step_ratio=1.0,
        cutoff_step_index=None,
    ):
        super().__init__(
            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index
        )

    # Tensor that this callback overrides in callback_kwargs at each step
    tensor_inputs = ["prompt_embeds"]

    def callback_fn(
        self, pipeline, step_index, timestep, callback_kwargs
    ) -> Dict[str, Any]:
        cutoff_step_ratio = self.config.cutoff_step_ratio
        cutoff_step_index = self.config.cutoff_step_index

        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
        cutoff_step = (
            cutoff_step_index
            if cutoff_step_index is not None
            else int(pipeline.num_timesteps * cutoff_step_ratio)
        )

        if step_index == cutoff_step:
            # Encode the scheduled prompt and swap it in for the remaining denoising steps
            prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
                prompt=self.config.prompt,
                negative_prompt=self.config.negative_prompt,
                device=pipeline._execution_device,
                num_images_per_prompt=self.config.num_images_per_prompt,
                do_classifier_free_guidance=pipeline.do_classifier_free_guidance,
            )
            if pipeline.do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
        return callback_kwargs


callback = MultiPipelineCallbacks(
    [
        SDPromptScheduleCallback(
            prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
            negative_prompt="Deformed, ugly, bad anatomy",
            cutoff_step_ratio=0.25,
        )
    ]
)

image = pipeline(
    prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
    negative_prompt="Deformed, ugly, bad anatomy",
    callback_on_step_end=callback,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
```
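
Because the callback is wrapped in `MultiPipelineCallbacks`, more than one prompt switch can be scheduled in a single generation by adding further `SDPromptScheduleCallback` instances with different cutoff steps. Below is a minimal sketch, assuming the `SDPromptScheduleCallback` class and `pipeline` defined above; the prompt texts are only illustrative placeholders.

```python
# Hypothetical extension of the example above: two prompt switches in one run,
# reusing the SDPromptScheduleCallback class and pipeline defined earlier.
callback = MultiPipelineCallbacks(
    [
        SDPromptScheduleCallback(
            prompt="Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
            negative_prompt="Deformed, ugly, bad anatomy",
            cutoff_step_ratio=0.25,  # switch the prompt after 25% of the steps
        ),
        SDPromptScheduleCallback(
            prompt="Official portrait of a smiling world war ii general, female, wearing glasses, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
            negative_prompt="Deformed, ugly, bad anatomy",
            cutoff_step_ratio=0.5,  # switch again at the halfway point
        ),
    ]
)

image = pipeline(
    prompt="Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski",
    negative_prompt="Deformed, ugly, bad anatomy",
    callback_on_step_end=callback,
    callback_on_step_end_tensor_inputs=["prompt_embeds"],
).images[0]
```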
