Skip to content

Commit 0c56d4a

Browse files
Ryan's suggested changes to extension manager/extensions
Co-Authored-By: Ryan Dick <[email protected]>
1 parent 710dc6b commit 0c56d4a

File tree

6 files changed

+82
-112
lines changed

6 files changed

+82
-112
lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
)
5858
from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
5959
from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
60+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
6061
from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
6162
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
6263
from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
@@ -790,7 +791,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
790791
ext_manager.add_extension(PreviewExt(step_callback))
791792

792793
# ext: t2i/ip adapter
793-
ext_manager.callbacks.setup(denoise_ctx, ext_manager)
794+
ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
794795

795796
unet_info = context.models.load(self.unet.unet)
796797
assert isinstance(unet_info.model, UNet2DConditionModel)

invokeai/backend/stable_diffusion/diffusion_backend.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from invokeai.app.services.config.config_default import get_config
99
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
1010
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
11+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
1112
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
1213

1314

@@ -41,23 +42,23 @@ def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsMa
4142

4243
# ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed)
4344
# ext: preview[pre_denoise_loop, priority=low]
44-
ext_manager.callbacks.pre_denoise_loop(ctx, ext_manager)
45+
ext_manager.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, ctx)
4546

4647
for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.inputs.timesteps)): # noqa: B020
4748
# ext: inpaint (apply mask to latents on non-inpaint models)
48-
ext_manager.callbacks.pre_step(ctx, ext_manager)
49+
ext_manager.run_callback(ExtensionCallbackType.PRE_STEP, ctx)
4950

5051
# ext: tiles? [override: step]
5152
ctx.step_output = self.step(ctx, ext_manager)
5253

5354
# ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models)
5455
# ext: preview[post_step, priority=low]
55-
ext_manager.callbacks.post_step(ctx, ext_manager)
56+
ext_manager.run_callback(ExtensionCallbackType.POST_STEP, ctx)
5657

5758
ctx.latents = ctx.step_output.prev_sample
5859

5960
# ext: inpaint[post_denoise_loop] (restore unmasked part)
60-
ext_manager.callbacks.post_denoise_loop(ctx, ext_manager)
61+
ext_manager.run_callback(ExtensionCallbackType.POST_DENOISE_LOOP, ctx)
6162
return ctx.latents
6263

6364
@torch.inference_mode()
@@ -80,7 +81,7 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler
8081

8182
# ext: cfg_rescale [modify_noise_prediction]
8283
# TODO: rename
83-
ext_manager.callbacks.post_apply_cfg(ctx, ext_manager)
84+
ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx)
8485

8586
# compute the previous noisy sample x_t -> x_t-1
8687
step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs)
@@ -120,14 +121,14 @@ def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditio
120121
ctx.inputs.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode)
121122

122123
# ext: controlnet/ip/t2i [pre_unet]
123-
ext_manager.callbacks.pre_unet(ctx, ext_manager)
124+
ext_manager.run_callback(ExtensionCallbackType.PRE_UNET, ctx)
124125

125126
# ext: inpaint [pre_unet, priority=low]
126127
# or
127128
# ext: inpaint [override: unet_forward]
128129
noise_pred = self._unet_forward(**vars(ctx.unet_kwargs))
129130

130-
ext_manager.callbacks.post_unet(ctx, ext_manager)
131+
ext_manager.run_callback(ExtensionCallbackType.POST_UNET, ctx)
131132

132133
# clean up locals
133134
ctx.unet_kwargs = None
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from enum import Enum
2+
3+
4+
class ExtensionCallbackType(Enum):
5+
SETUP = "setup"
6+
PRE_DENOISE_LOOP = "pre_denoise_loop"
7+
POST_DENOISE_LOOP = "post_denoise_loop"
8+
PRE_STEP = "pre_step"
9+
POST_STEP = "post_step"
10+
PRE_UNET = "pre_unet"
11+
POST_UNET = "post_unet"
12+
POST_APPLY_CFG = "post_apply_cfg"

invokeai/backend/stable_diffusion/extensions/base.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,54 @@
22

33
from contextlib import contextmanager
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
5+
from typing import TYPE_CHECKING, Callable, Dict, List
66

77
import torch
88
from diffusers import UNet2DConditionModel
99

1010
if TYPE_CHECKING:
1111
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
12+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
1213

1314

1415
@dataclass
15-
class InjectionInfo:
16-
type: str
17-
name: str
18-
order: Optional[int]
19-
function: Callable
20-
21-
22-
def callback(name: str, order: int = 0):
23-
def _decorator(func):
24-
func.__inj_info__ = {
25-
"type": "callback",
26-
"name": name,
27-
"order": order,
28-
}
29-
return func
16+
class CallbackMetadata:
17+
callback_type: ExtensionCallbackType
18+
order: int
19+
20+
21+
@dataclass
22+
class CallbackFunctionWithMetadata:
23+
metadata: CallbackMetadata
24+
function: Callable[[DenoiseContext], None]
25+
26+
27+
def callback(callback_type: ExtensionCallbackType, order: int = 0):
28+
def _decorator(function):
29+
function._ext_metadata = CallbackMetadata(
30+
callback_type=callback_type,
31+
order=order,
32+
)
33+
return function
3034

3135
return _decorator
3236

3337

3438
class ExtensionBase:
3539
def __init__(self):
36-
self.injections: List[InjectionInfo] = []
40+
self._callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
41+
42+
# Register all of the callback methods for this instance.
3743
for func_name in dir(self):
3844
func = getattr(self, func_name)
39-
if not callable(func) or not hasattr(func, "__inj_info__"):
40-
continue
41-
42-
self.injections.append(InjectionInfo(**func.__inj_info__, function=func))
45+
metadata = getattr(func, "_ext_metadata", None)
46+
if metadata is not None and isinstance(metadata, CallbackMetadata):
47+
if metadata.callback_type not in self._callbacks:
48+
self._callbacks[metadata.callback_type] = []
49+
self._callbacks[metadata.callback_type].append(CallbackFunctionWithMetadata(metadata, func))
50+
51+
def get_callbacks(self):
52+
return self._callbacks
4353

4454
@contextmanager
4555
def patch_extension(self, context: DenoiseContext):

invokeai/backend/stable_diffusion/extensions/preview.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
import torch
77

8+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
89
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
910

1011
if TYPE_CHECKING:
1112
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
12-
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
1313

1414

1515
# TODO: change event to accept image instead of latents
@@ -29,8 +29,8 @@ def __init__(self, callback: Callable[[PipelineIntermediateState], None]):
2929
self.callback = callback
3030

3131
# do last so that all other changes shown
32-
@callback("pre_denoise_loop", order=1000)
33-
def initial_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
32+
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
33+
def initial_preview(self, ctx: DenoiseContext):
3434
self.callback(
3535
PipelineIntermediateState(
3636
step=-1,
@@ -42,8 +42,8 @@ def initial_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
4242
)
4343

4444
# do last so that all other changes shown
45-
@callback("post_step", order=1000)
46-
def step_preview(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
45+
@callback(ExtensionCallbackType.POST_STEP, order=1000)
46+
def step_preview(self, ctx: DenoiseContext):
4747
if hasattr(ctx.step_output, "denoised"):
4848
predicted_original = ctx.step_output.denoised
4949
elif hasattr(ctx.step_output, "pred_original_sample"):

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 24 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from __future__ import annotations
22

3-
from abc import ABC, abstractmethod
43
from contextlib import ExitStack, contextmanager
5-
from functools import partial
64
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
75

86
import torch
@@ -12,102 +10,50 @@
1210

1311
if TYPE_CHECKING:
1412
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
15-
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
16-
17-
18-
class ExtCallbacksApi(ABC):
19-
@abstractmethod
20-
def setup(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
21-
pass
22-
23-
@abstractmethod
24-
def pre_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
25-
pass
26-
27-
@abstractmethod
28-
def post_denoise_loop(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
29-
pass
30-
31-
@abstractmethod
32-
def pre_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
33-
pass
34-
35-
@abstractmethod
36-
def post_step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
37-
pass
38-
39-
@abstractmethod
40-
def pre_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
41-
pass
42-
43-
@abstractmethod
44-
def post_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
45-
pass
46-
47-
@abstractmethod
48-
def post_apply_cfg(self, ctx: DenoiseContext, ext_manager: ExtensionsManager):
49-
pass
50-
51-
52-
class ProxyCallsClass:
53-
def __init__(self, handler):
54-
self._handler = handler
55-
56-
def __getattr__(self, item):
57-
return partial(self._handler, item)
58-
59-
60-
class CallbackInjectionPoint:
61-
def __init__(self):
62-
self.handlers = {}
63-
64-
def add(self, func: Callable, order: int):
65-
if order not in self.handlers:
66-
self.handlers[order] = []
67-
self.handlers[order].append(func)
68-
69-
def __call__(self, *args, **kwargs):
70-
for order in sorted(self.handlers.keys(), reverse=True):
71-
for handler in self.handlers[order]:
72-
handler(*args, **kwargs)
13+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
14+
from invokeai.backend.stable_diffusion.extensions.base import CallbackFunctionWithMetadata, ExtensionBase
7315

7416

7517
class ExtensionsManager:
7618
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
77-
self.extensions: List[ExtensionBase] = []
7819
self._is_canceled = is_canceled
7920

80-
self._callbacks: Dict[str, CallbackInjectionPoint] = {}
81-
self.callbacks: ExtCallbacksApi = ProxyCallsClass(self.call_callback)
21+
self._extensions: List[ExtensionBase] = []
22+
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
8223

83-
def add_extension(self, ext: ExtensionBase):
84-
self.extensions.append(ext)
24+
def add_extension(self, extension: ExtensionBase):
25+
self._extensions.append(extension)
26+
self._regenerate_ordered_callbacks()
8527

86-
self._callbacks.clear()
28+
def _regenerate_ordered_callbacks(self):
29+
"""Regenerates self._ordered_callbacks. Intended to be called each time a new extension is added."""
30+
self._ordered_callbacks = {}
8731

88-
for ext in self.extensions:
89-
for inj_info in ext.injections:
90-
if inj_info.type == "callback":
91-
if inj_info.name not in self._callbacks:
92-
self._callbacks[inj_info.name] = CallbackInjectionPoint()
93-
self._callbacks[inj_info.name].add(inj_info.function, inj_info.order)
32+
# Fill the ordered callbacks dictionary.
33+
for extension in self._extensions:
34+
for callback_type, callbacks in extension.get_callbacks().items():
35+
if callback_type not in self._ordered_callbacks:
36+
self._ordered_callbacks[callback_type] = []
37+
self._ordered_callbacks[callback_type].extend(callbacks)
9438

95-
else:
96-
raise Exception(f"Unsupported injection type: {inj_info.type}")
39+
# Sort each callback list.
40+
for callback_type, callbacks in self._ordered_callbacks.items():
41+
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
9742

98-
def call_callback(self, name: str, *args, **kwargs):
43+
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
9944
# TODO: add to patchers too?
10045
# and if so, should it be only in beginning of function or in for loop
10146
if self._is_canceled and self._is_canceled():
10247
raise CanceledException
10348

104-
if name in self._callbacks:
105-
self._callbacks[name](*args, **kwargs)
49+
callbacks = self._ordered_callbacks.get(callback_type, [])
50+
for cb in callbacks:
51+
cb.function(ctx)
10652

10753
@contextmanager
10854
def patch_extensions(self, context: DenoiseContext):
10955
with ExitStack() as exit_stack:
110-
for ext in self.extensions:
56+
for ext in self._extensions:
11157
exit_stack.enter_context(ext.patch_extension(context))
11258

11359
yield None

0 commit comments

Comments
 (0)