Skip to content

Commit 2026ec0

Browse files
DN6stevhliupatrickvonplatensayakpaul
authored
Interruptable Pipelines (#5867)
* add interruptable pipelines * add tests * updatemsmq * add interrupt property * make fix copies * Revert "make fix copies" This reverts commit 914b353. * add docs * add tutorial * Update docs/source/en/tutorials/interrupting_diffusion_process.md Co-authored-by: Steven Liu <[email protected]> * Update docs/source/en/tutorials/interrupting_diffusion_process.md Co-authored-by: Steven Liu <[email protected]> * update * fix quality issues * fix * update --------- Co-authored-by: Steven Liu <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 3706aa3 commit 2026ec0

13 files changed

+422
-0
lines changed

docs/source/en/using-diffusers/callback.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,42 @@ With callbacks, you can implement features such as dynamic CFG without having to
6363
🤗 Diffusers currently only supports `callback_on_step_end`, but feel free to open a [feature request](https://github.com/huggingface/diffusers/issues/new/choose) if you have a cool use-case and require a callback function with a different execution point!
6464

6565
</Tip>
66+
67+
68+
## Using Callbacks to interrupt the Diffusion Process
69+
70+
The following pipelines support interrupting the diffusion process via a callback:
71+
72+
- [StableDiffusionPipeline](../api/pipelines/stable_diffusion/overview.md)
73+
- [StableDiffusionImg2ImgPipeline](../api/pipelines/stable_diffusion/img2img.md)
74+
- [StableDiffusionInpaintPipeline](../api/pipelines/stable_diffusion/inpaint.md)
75+
- [StableDiffusionXLPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
76+
- [StableDiffusionXLImg2ImgPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
77+
- [StableDiffusionXLInpaintPipeline](../api/pipelines/stable_diffusion/stable_diffusion_xl.md)
78+
79+
Interrupting the diffusion process is particularly useful when building UIs that work with Diffusers because it allows users to stop the generation process if they're unhappy with the intermediate results. You can incorporate this into your pipeline with a callback.
80+
81+
This callback function should take the following arguments: `pipe`, `i`, `t`, and `callback_kwargs` (this must be returned). Set the pipeline's `_interrupt` attribute to `True` to stop the diffusion process after a certain number of steps. You are also free to implement your own custom stopping logic inside the callback.
82+
83+
In this example, the diffusion process is stopped after 10 steps even though `num_inference_steps` is set to 50.
84+
85+
```python
86+
from diffusers import StableDiffusionPipeline
87+
88+
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
89+
pipe.enable_model_cpu_offload()
90+
num_inference_steps = 50
91+
92+
def interrupt_callback(pipe, i, t, callback_kwargs):
93+
stop_idx = 10
94+
if i == stop_idx:
95+
pipe._interrupt = True
96+
97+
return callback_kwargs
98+
99+
pipe(
100+
"A photo of a cat",
101+
num_inference_steps=num_inference_steps,
102+
callback_on_step_end=interrupt_callback,
103+
)
104+
```

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,10 @@ def cross_attention_kwargs(self):
768768
def num_timesteps(self):
769769
return self._num_timesteps
770770

771+
@property
772+
def interrupt(self):
773+
return self._interrupt
774+
771775
@torch.no_grad()
772776
@replace_example_docstring(EXAMPLE_DOC_STRING)
773777
def __call__(
@@ -909,6 +913,7 @@ def __call__(
909913
self._guidance_rescale = guidance_rescale
910914
self._clip_skip = clip_skip
911915
self._cross_attention_kwargs = cross_attention_kwargs
916+
self._interrupt = False
912917

913918
# 2. Define call parameters
914919
if prompt is not None and isinstance(prompt, str):
@@ -986,6 +991,9 @@ def __call__(
986991
self._num_timesteps = len(timesteps)
987992
with self.progress_bar(total=num_inference_steps) as progress_bar:
988993
for i, t in enumerate(timesteps):
994+
if self.interrupt:
995+
continue
996+
989997
# expand the latents if we are doing classifier free guidance
990998
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
991999
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -832,6 +832,10 @@ def cross_attention_kwargs(self):
832832
def num_timesteps(self):
833833
return self._num_timesteps
834834

835+
@property
836+
def interrupt(self):
837+
return self._interrupt
838+
835839
@torch.no_grad()
836840
@replace_example_docstring(EXAMPLE_DOC_STRING)
837841
def __call__(
@@ -963,6 +967,7 @@ def __call__(
963967
self._guidance_scale = guidance_scale
964968
self._clip_skip = clip_skip
965969
self._cross_attention_kwargs = cross_attention_kwargs
970+
self._interrupt = False
966971

967972
# 2. Define call parameters
968973
if prompt is not None and isinstance(prompt, str):
@@ -1041,6 +1046,9 @@ def __call__(
10411046
self._num_timesteps = len(timesteps)
10421047
with self.progress_bar(total=num_inference_steps) as progress_bar:
10431048
for i, t in enumerate(timesteps):
1049+
if self.interrupt:
1050+
continue
1051+
10441052
# expand the latents if we are doing classifier free guidance
10451053
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
10461054
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,10 @@ def cross_attention_kwargs(self):
958958
def num_timesteps(self):
959959
return self._num_timesteps
960960

961+
@property
962+
def interrupt(self):
963+
return self._interrupt
964+
961965
@torch.no_grad()
962966
def __call__(
963967
self,
@@ -1144,6 +1148,7 @@ def __call__(
11441148
self._guidance_scale = guidance_scale
11451149
self._clip_skip = clip_skip
11461150
self._cross_attention_kwargs = cross_attention_kwargs
1151+
self._interrupt = False
11471152

11481153
# 2. Define call parameters
11491154
if prompt is not None and isinstance(prompt, str):
@@ -1288,6 +1293,9 @@ def __call__(
12881293
self._num_timesteps = len(timesteps)
12891294
with self.progress_bar(total=num_inference_steps) as progress_bar:
12901295
for i, t in enumerate(timesteps):
1296+
if self.interrupt:
1297+
continue
1298+
12911299
# expand the latents if we are doing classifier free guidance
12921300
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
12931301

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,10 @@ def denoising_end(self):
849849
def num_timesteps(self):
850850
return self._num_timesteps
851851

852+
@property
853+
def interrupt(self):
854+
return self._interrupt
855+
852856
@torch.no_grad()
853857
@replace_example_docstring(EXAMPLE_DOC_STRING)
854858
def __call__(
@@ -1067,6 +1071,7 @@ def __call__(
10671071
self._clip_skip = clip_skip
10681072
self._cross_attention_kwargs = cross_attention_kwargs
10691073
self._denoising_end = denoising_end
1074+
self._interrupt = False
10701075

10711076
# 2. Define call parameters
10721077
if prompt is not None and isinstance(prompt, str):
@@ -1196,6 +1201,9 @@ def __call__(
11961201
self._num_timesteps = len(timesteps)
11971202
with self.progress_bar(total=num_inference_steps) as progress_bar:
11981203
for i, t in enumerate(timesteps):
1204+
if self.interrupt:
1205+
continue
1206+
11991207
# expand the latents if we are doing classifier free guidance
12001208
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
12011209

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,10 @@ def denoising_start(self):
990990
def num_timesteps(self):
991991
return self._num_timesteps
992992

993+
@property
994+
def interrupt(self):
995+
return self._interrupt
996+
993997
@torch.no_grad()
994998
@replace_example_docstring(EXAMPLE_DOC_STRING)
995999
def __call__(
@@ -1221,6 +1225,7 @@ def __call__(
12211225
self._cross_attention_kwargs = cross_attention_kwargs
12221226
self._denoising_end = denoising_end
12231227
self._denoising_start = denoising_start
1228+
self._interrupt = False
12241229

12251230
# 2. Define call parameters
12261231
if prompt is not None and isinstance(prompt, str):
@@ -1376,6 +1381,9 @@ def denoising_value_valid(dnv):
13761381
self._num_timesteps = len(timesteps)
13771382
with self.progress_bar(total=num_inference_steps) as progress_bar:
13781383
for i, t in enumerate(timesteps):
1384+
if self.interrupt:
1385+
continue
1386+
13791387
# expand the latents if we are doing classifier free guidance
13801388
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
13811389

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,6 +1210,10 @@ def denoising_start(self):
12101210
def num_timesteps(self):
12111211
return self._num_timesteps
12121212

1213+
@property
1214+
def interrupt(self):
1215+
return self._interrupt
1216+
12131217
@torch.no_grad()
12141218
@replace_example_docstring(EXAMPLE_DOC_STRING)
12151219
def __call__(
@@ -1462,6 +1466,7 @@ def __call__(
14621466
self._cross_attention_kwargs = cross_attention_kwargs
14631467
self._denoising_end = denoising_end
14641468
self._denoising_start = denoising_start
1469+
self._interrupt = False
14651470

14661471
# 2. Define call parameters
14671472
if prompt is not None and isinstance(prompt, str):
@@ -1684,6 +1689,8 @@ def denoising_value_valid(dnv):
16841689
self._num_timesteps = len(timesteps)
16851690
with self.progress_bar(total=num_inference_steps) as progress_bar:
16861691
for i, t in enumerate(timesteps):
1692+
if self.interrupt:
1693+
continue
16871694
# expand the latents if we are doing classifier free guidance
16881695
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
16891696

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,58 @@ def test_fused_qkv_projections(self):
692692
original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2
693693
), "Original outputs should match when fused QKV projections are disabled."
694694

695+
def test_pipeline_interrupt(self):
696+
components = self.get_dummy_components()
697+
sd_pipe = StableDiffusionPipeline(**components)
698+
sd_pipe = sd_pipe.to(torch_device)
699+
sd_pipe.set_progress_bar_config(disable=None)
700+
701+
prompt = "hey"
702+
num_inference_steps = 3
703+
704+
# store intermediate latents from the generation process
705+
class PipelineState:
706+
def __init__(self):
707+
self.state = []
708+
709+
def apply(self, pipe, i, t, callback_kwargs):
710+
self.state.append(callback_kwargs["latents"])
711+
return callback_kwargs
712+
713+
pipe_state = PipelineState()
714+
sd_pipe(
715+
prompt,
716+
num_inference_steps=num_inference_steps,
717+
output_type="np",
718+
generator=torch.Generator("cpu").manual_seed(0),
719+
callback_on_step_end=pipe_state.apply,
720+
).images
721+
722+
# interrupt generation at step index
723+
interrupt_step_idx = 1
724+
725+
def callback_on_step_end(pipe, i, t, callback_kwargs):
726+
if i == interrupt_step_idx:
727+
pipe._interrupt = True
728+
729+
return callback_kwargs
730+
731+
output_interrupted = sd_pipe(
732+
prompt,
733+
num_inference_steps=num_inference_steps,
734+
output_type="latent",
735+
generator=torch.Generator("cpu").manual_seed(0),
736+
callback_on_step_end=callback_on_step_end,
737+
).images
738+
739+
# fetch intermediate latents at the interrupted step
740+
# from the completed generation process
741+
intermediate_latent = pipe_state.state[interrupt_step_idx]
742+
743+
# compare the intermediate latent to the output of the interrupted process
744+
# they should be the same
745+
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
746+
695747

696748
@slow
697749
@require_torch_gpu

tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,62 @@ def test_inference_batch_single_identical(self):
320320
def test_float16_inference(self):
321321
super().test_float16_inference(expected_max_diff=5e-1)
322322

323+
def test_pipeline_interrupt(self):
324+
components = self.get_dummy_components()
325+
sd_pipe = StableDiffusionImg2ImgPipeline(**components)
326+
sd_pipe = sd_pipe.to(torch_device)
327+
sd_pipe.set_progress_bar_config(disable=None)
328+
329+
inputs = self.get_dummy_inputs(torch_device)
330+
331+
prompt = "hey"
332+
num_inference_steps = 3
333+
334+
# store intermediate latents from the generation process
335+
class PipelineState:
336+
def __init__(self):
337+
self.state = []
338+
339+
def apply(self, pipe, i, t, callback_kwargs):
340+
self.state.append(callback_kwargs["latents"])
341+
return callback_kwargs
342+
343+
pipe_state = PipelineState()
344+
sd_pipe(
345+
prompt,
346+
image=inputs["image"],
347+
num_inference_steps=num_inference_steps,
348+
output_type="np",
349+
generator=torch.Generator("cpu").manual_seed(0),
350+
callback_on_step_end=pipe_state.apply,
351+
).images
352+
353+
# interrupt generation at step index
354+
interrupt_step_idx = 1
355+
356+
def callback_on_step_end(pipe, i, t, callback_kwargs):
357+
if i == interrupt_step_idx:
358+
pipe._interrupt = True
359+
360+
return callback_kwargs
361+
362+
output_interrupted = sd_pipe(
363+
prompt,
364+
image=inputs["image"],
365+
num_inference_steps=num_inference_steps,
366+
output_type="latent",
367+
generator=torch.Generator("cpu").manual_seed(0),
368+
callback_on_step_end=callback_on_step_end,
369+
).images
370+
371+
# fetch intermediate latents at the interrupted step
372+
# from the completed generation process
373+
intermediate_latent = pipe_state.state[interrupt_step_idx]
374+
375+
# compare the intermediate latent to the output of the interrupted process
376+
# they should be the same
377+
assert torch.allclose(intermediate_latent, output_interrupted, atol=1e-4)
378+
323379

324380
@slow
325381
@require_torch_gpu

0 commit comments

Comments
 (0)