3
3
from abc import ABC , abstractmethod
4
4
from contextlib import ExitStack , contextmanager
5
5
from functools import partial
6
- from typing import TYPE_CHECKING , Callable , Dict
6
+ from typing import TYPE_CHECKING , Callable , Dict , List , Optional
7
7
8
8
import torch
9
9
from diffusers import UNet2DConditionModel
10
10
11
+ from invokeai .app .services .session_processor .session_processor_common import CanceledException
12
+
11
13
if TYPE_CHECKING :
12
14
from invokeai .backend .stable_diffusion .denoise_context import DenoiseContext
13
- from invokeai .backend .stable_diffusion .extensions import ExtensionBase
15
+ from invokeai .backend .stable_diffusion .extensions . base import ExtensionBase
14
16
15
17
16
18
class ExtCallbacksApi (ABC ):
@@ -71,10 +73,11 @@ def __call__(self, *args, **kwargs):
71
73
72
74
73
75
class ExtensionsManager :
74
- def __init__ (self ):
75
- self .extensions = []
76
+ def __init__ (self , is_canceled : Optional [Callable [[], bool ]] = None ):
77
+ self .extensions : List [ExtensionBase ] = []
78
+ self ._is_canceled = is_canceled
76
79
77
- self ._callbacks = {}
80
+ self ._callbacks : Dict [ str , CallbackInjectionPoint ] = {}
78
81
self .callbacks : ExtCallbacksApi = ProxyCallsClass (self .call_callback )
79
82
80
83
def add_extension (self , ext : ExtensionBase ):
@@ -93,6 +96,11 @@ def add_extension(self, ext: ExtensionBase):
93
96
raise Exception (f"Unsupported injection type: { inj_info .type } " )
94
97
95
98
def call_callback (self , name : str , * args , ** kwargs ):
99
+ # TODO: add to patchers too?
100
+ # and if so, should it be only in beginning of function or in for loop
101
+ if self ._is_canceled and self ._is_canceled ():
102
+ raise CanceledException
103
+
96
104
if name in self ._callbacks :
97
105
self ._callbacks [name ](* args , ** kwargs )
98
106
0 commit comments