Skip to content

Commit 2ef3b49

Browse files
committed
Add run cancelling logic to extension manager
1 parent 3f79467 commit 2ef3b49

File tree

2 files changed

+14
-6
lines changed

2 files changed

+14
-6
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
723723
@torch.no_grad()
724724
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
725725
def _new_invoke(self, context: InvocationContext) -> LatentsOutput:
726-
ext_manager = ExtensionsManager()
726+
ext_manager = ExtensionsManager(is_canceled=context.util.is_canceled)
727727

728728
device = TorchDevice.choose_torch_device()
729729
dtype = TorchDevice.choose_torch_dtype()

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@
33
from abc import ABC, abstractmethod
44
from contextlib import ExitStack, contextmanager
55
from functools import partial
6-
from typing import TYPE_CHECKING, Callable, Dict
6+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
77

88
import torch
99
from diffusers import UNet2DConditionModel
1010

11+
from invokeai.app.services.session_processor.session_processor_common import CanceledException
12+
1113
if TYPE_CHECKING:
1214
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
13-
from invokeai.backend.stable_diffusion.extensions import ExtensionBase
15+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
1416

1517

1618
class ExtCallbacksApi(ABC):
@@ -71,10 +73,11 @@ def __call__(self, *args, **kwargs):
7173

7274

7375
class ExtensionsManager:
74-
def __init__(self):
75-
self.extensions = []
76+
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
77+
self.extensions: List[ExtensionBase] = []
78+
self._is_canceled = is_canceled
7679

77-
self._callbacks = {}
80+
self._callbacks: Dict[str, CallbackInjectionPoint] = {}
7881
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
7982

8083
def add_extension(self, ext: ExtensionBase):
@@ -93,6 +96,11 @@ def add_extension(self, ext: ExtensionBase):
9396
raise Exception(f"Unsupported injection type: {inj_info.type}")
9497

9598
def call_callback(self, name: str, *args, **kwargs):
99+
# TODO: add to patchers too?
100+
# and if so, should it be only in beginning of function or in for loop
101+
if self._is_canceled and self._is_canceled():
102+
raise CanceledException
103+
96104
if name in self._callbacks:
97105
self._callbacks[name](*args, **kwargs)
98106

0 commit comments

Comments
 (0)