File tree Expand file tree Collapse file tree 6 files changed +48
-0
lines changed 
src/diffusers/pipelines/controlnet Expand file tree Collapse file tree 6 files changed +48
-0
lines changed Original file line number Diff line number Diff line change @@ -893,6 +893,10 @@ def cross_attention_kwargs(self):
893893    def  num_timesteps (self ):
894894        return  self ._num_timesteps 
895895
896+     @property  
897+     def  interrupt (self ):
898+         return  self ._interrupt 
899+ 
896900    @torch .no_grad () 
897901    @replace_example_docstring (EXAMPLE_DOC_STRING ) 
898902    def  __call__ (
@@ -1089,6 +1093,7 @@ def __call__(
10891093        self ._guidance_scale  =  guidance_scale 
10901094        self ._clip_skip  =  clip_skip 
10911095        self ._cross_attention_kwargs  =  cross_attention_kwargs 
1096+         self ._interrupt  =  False 
10921097
10931098        # 2. Define call parameters 
10941099        if  prompt  is  not None  and  isinstance (prompt , str ):
@@ -1235,6 +1240,9 @@ def __call__(
12351240        is_torch_higher_equal_2_1  =  is_torch_version (">=" , "2.1" )
12361241        with  self .progress_bar (total = num_inference_steps ) as  progress_bar :
12371242            for  i , t  in  enumerate (timesteps ):
1243+                 if  self .interrupt :
1244+                     continue 
1245+ 
12381246                # Relevant thread: 
12391247                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 
12401248                if  (is_unet_compiled  and  is_controlnet_compiled ) and  is_torch_higher_equal_2_1 :
Original file line number Diff line number Diff line change @@ -891,6 +891,10 @@ def cross_attention_kwargs(self):
891891    def  num_timesteps (self ):
892892        return  self ._num_timesteps 
893893
894+     @property  
895+     def  interrupt (self ):
896+         return  self ._interrupt 
897+ 
894898    @torch .no_grad () 
895899    @replace_example_docstring (EXAMPLE_DOC_STRING ) 
896900    def  __call__ (
@@ -1081,6 +1085,7 @@ def __call__(
10811085        self ._guidance_scale  =  guidance_scale 
10821086        self ._clip_skip  =  clip_skip 
10831087        self ._cross_attention_kwargs  =  cross_attention_kwargs 
1088+         self ._interrupt  =  False 
10841089
10851090        # 2. Define call parameters 
10861091        if  prompt  is  not None  and  isinstance (prompt , str ):
@@ -1211,6 +1216,9 @@ def __call__(
12111216        num_warmup_steps  =  len (timesteps ) -  num_inference_steps  *  self .scheduler .order 
12121217        with  self .progress_bar (total = num_inference_steps ) as  progress_bar :
12131218            for  i , t  in  enumerate (timesteps ):
1219+                 if  self .interrupt :
1220+                     continue 
1221+ 
12141222                # expand the latents if we are doing classifier free guidance 
12151223                latent_model_input  =  torch .cat ([latents ] *  2 ) if  self .do_classifier_free_guidance  else  latents 
12161224                latent_model_input  =  self .scheduler .scale_model_input (latent_model_input , t )
Original file line number Diff line number Diff line change @@ -976,6 +976,10 @@ def cross_attention_kwargs(self):
976976    def  num_timesteps (self ):
977977        return  self ._num_timesteps 
978978
979+     @property  
980+     def  interrupt (self ):
981+         return  self ._interrupt 
982+ 
979983    @torch .no_grad () 
980984    @replace_example_docstring (EXAMPLE_DOC_STRING ) 
981985    def  __call__ (
@@ -1191,6 +1195,7 @@ def __call__(
11911195        self ._guidance_scale  =  guidance_scale 
11921196        self ._clip_skip  =  clip_skip 
11931197        self ._cross_attention_kwargs  =  cross_attention_kwargs 
1198+         self ._interrupt  =  False 
11941199
11951200        # 2. Define call parameters 
11961201        if  prompt  is  not None  and  isinstance (prompt , str ):
@@ -1375,6 +1380,9 @@ def __call__(
13751380        num_warmup_steps  =  len (timesteps ) -  num_inference_steps  *  self .scheduler .order 
13761381        with  self .progress_bar (total = num_inference_steps ) as  progress_bar :
13771382            for  i , t  in  enumerate (timesteps ):
1383+                 if  self .interrupt :
1384+                     continue 
1385+ 
13781386                # expand the latents if we are doing classifier free guidance 
13791387                latent_model_input  =  torch .cat ([latents ] *  2 ) if  self .do_classifier_free_guidance  else  latents 
13801388                latent_model_input  =  self .scheduler .scale_model_input (latent_model_input , t )
Original file line number Diff line number Diff line change @@ -1145,6 +1145,10 @@ def cross_attention_kwargs(self):
11451145    def  num_timesteps (self ):
11461146        return  self ._num_timesteps 
11471147
1148+     @property  
1149+     def  interrupt (self ):
1150+         return  self ._interrupt 
1151+ 
11481152    @torch .no_grad () 
11491153    @replace_example_docstring (EXAMPLE_DOC_STRING ) 
11501154    def  __call__ (
@@ -1427,6 +1431,7 @@ def __call__(
14271431        self ._guidance_scale  =  guidance_scale 
14281432        self ._clip_skip  =  clip_skip 
14291433        self ._cross_attention_kwargs  =  cross_attention_kwargs 
1434+         self ._interrupt  =  False 
14301435
14311436        # 2. Define call parameters 
14321437        if  prompt  is  not None  and  isinstance (prompt , str ):
@@ -1695,6 +1700,9 @@ def denoising_value_valid(dnv):
16951700
16961701        with  self .progress_bar (total = num_inference_steps ) as  progress_bar :
16971702            for  i , t  in  enumerate (timesteps ):
1703+                 if  self .interrupt :
1704+                     continue 
1705+ 
16981706                # expand the latents if we are doing classifier free guidance 
16991707                latent_model_input  =  torch .cat ([latents ] *  2 ) if  self .do_classifier_free_guidance  else  latents 
17001708
Original file line number Diff line number Diff line change @@ -990,6 +990,10 @@ def denoising_end(self):
990990    def  num_timesteps (self ):
991991        return  self ._num_timesteps 
992992
993+     @property  
994+     def  interrupt (self ):
995+         return  self ._interrupt 
996+ 
993997    @torch .no_grad () 
994998    @replace_example_docstring (EXAMPLE_DOC_STRING ) 
995999    def  __call__ (
@@ -1245,6 +1249,7 @@ def __call__(
12451249        self ._clip_skip  =  clip_skip 
12461250        self ._cross_attention_kwargs  =  cross_attention_kwargs 
12471251        self ._denoising_end  =  denoising_end 
1252+         self ._interrupt  =  False 
12481253
12491254        # 2. Define call parameters 
12501255        if  prompt  is  not None  and  isinstance (prompt , str ):
@@ -1442,6 +1447,9 @@ def __call__(
14421447        is_torch_higher_equal_2_1  =  is_torch_version (">=" , "2.1" )
14431448        with  self .progress_bar (total = num_inference_steps ) as  progress_bar :
14441449            for  i , t  in  enumerate (timesteps ):
1450+                 if  self .interrupt :
1451+                     continue 
1452+ 
14451453                # Relevant thread: 
14461454                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 
14471455                if  (is_unet_compiled  and  is_controlnet_compiled ) and  is_torch_higher_equal_2_1 :
Original file line number Diff line number Diff line change @@ -1070,6 +1070,10 @@ def cross_attention_kwargs(self):
10701070    def  num_timesteps (self ):
10711071        return  self ._num_timesteps 
10721072
1073+     @property  
1074+     def  interrupt (self ):
1075+         return  self ._interrupt 
1076+ 
10731077    @torch .no_grad () 
10741078    @replace_example_docstring (EXAMPLE_DOC_STRING ) 
10751079    def  __call__ (
@@ -1338,6 +1342,7 @@ def __call__(
13381342        self ._guidance_scale  =  guidance_scale 
13391343        self ._clip_skip  =  clip_skip 
13401344        self ._cross_attention_kwargs  =  cross_attention_kwargs 
1345+         self ._interrupt  =  False 
13411346
13421347        # 2. Define call parameters 
13431348        if  prompt  is  not None  and  isinstance (prompt , str ):
@@ -1510,6 +1515,9 @@ def __call__(
15101515        num_warmup_steps  =  len (timesteps ) -  num_inference_steps  *  self .scheduler .order 
15111516        with  self .progress_bar (total = num_inference_steps ) as  progress_bar :
15121517            for  i , t  in  enumerate (timesteps ):
1518+                 if  self .interrupt :
1519+                     continue 
1520+ 
15131521                # expand the latents if we are doing classifier free guidance 
15141522                latent_model_input  =  torch .cat ([latents ] *  2 ) if  self .do_classifier_free_guidance  else  latents 
15151523                latent_model_input  =  self .scheduler .scale_model_input (latent_model_input , t )
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments