6 files changed, +48 −0 lines

src/diffusers/pipelines/controlnet

@@ -893,6 +893,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1089,6 +1093,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1235,6 +1240,9 @@ def __call__(
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
(file 2 of 6)

@@ -890,6 +890,10 @@ def cross_attention_kwargs(self):
     @property
     def num_timesteps(self):
         return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -1081,6 +1085,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1211,6 +1216,9 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
(file 3 of 6)

@@ -976,6 +976,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1191,6 +1195,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1375,6 +1380,9 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
(file 4 of 6)

@@ -1144,6 +1144,10 @@ def cross_attention_kwargs(self):
     @property
     def num_timesteps(self):
         return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -1427,6 +1431,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1695,6 +1700,9 @@ def denoising_value_valid(dnv):

         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

(file 5 of 6)

@@ -989,6 +989,10 @@ def denoising_end(self):
     @property
     def num_timesteps(self):
         return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt

     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
@@ -1245,6 +1249,7 @@ def __call__(
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
         self._denoising_end = denoising_end
+        self._interrupt = False

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1442,6 +1447,9 @@ def __call__(
         is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # Relevant thread:
                 # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
(file 6 of 6)

@@ -1070,6 +1070,10 @@ def cross_attention_kwargs(self):
     def num_timesteps(self):
         return self._num_timesteps

+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -1338,6 +1342,7 @@ def __call__(
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
         self._cross_attention_kwargs = cross_attention_kwargs
+        self._interrupt = False

         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
@@ -1510,6 +1515,9 @@ def __call__(
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
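
Each of the six pipelines gets the same three-part change: a read-only interrupt property, a reset of _interrupt to False at the top of __call__, and an early continue at the head of the denoising loop. Using continue rather than break means the remaining iterations become no-ops while the progress bar and the post-loop decoding still run to completion. Below is a minimal usage sketch, not taken from this diff: it assumes the pipeline's __call__ accepts callback_on_step_end (which these hunks do not show), and the checkpoint names and step threshold are illustrative.

from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Checkpoint names are illustrative assumptions, not part of this diff.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet
)

def stop_after_ten_steps(pipeline, step, timestep, callback_kwargs):
    # Flipping the private flag turns every remaining denoising iteration
    # into a no-op `continue`; the progress bar still runs to completion.
    if step >= 10:
        pipeline._interrupt = True
    return callback_kwargs

control_image = Image.new("RGB", (512, 512))  # placeholder control map
image = pipe(
    "a photo of a cat",
    image=control_image,
    callback_on_step_end=stop_after_ten_steps,
).images[0]

Because _interrupt is reset to False at the start of every __call__, an interrupted pipeline object can be reused for the next generation without any manual cleanup.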