Skip to content

Commit 83a86ab

Browse files
committed
Add unit tests for ExtensionsManager and ExtensionBase.
1 parent 0c56d4a commit 83a86ab

File tree

3 files changed

+161
-0
lines changed

3 files changed

+161
-0
lines changed

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class ExtensionsManager:
1818
def __init__(self, is_canceled: Optional[Callable[[], bool]] = None):
1919
self._is_canceled = is_canceled
2020

21+
# A list of extensions in the order that they were added to the ExtensionsManager.
2122
self._extensions: List[ExtensionBase] = []
2223
self._ordered_callbacks: Dict[ExtensionCallbackType, List[CallbackFunctionWithMetadata]] = {}
2324

@@ -38,6 +39,8 @@ def _regenerate_ordered_callbacks(self):
3839

3940
# Sort each callback list.
4041
for callback_type, callbacks in self._ordered_callbacks.items():
42+
# Note that sorted() is stable, so if two callbacks have the same order, the order that they extensions were
43+
# added will be preserved.
4144
self._ordered_callbacks[callback_type] = sorted(callbacks, key=lambda x: x.metadata.order)
4245

4346
def run_callback(self, callback_type: ExtensionCallbackType, ctx: DenoiseContext):
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from unittest import mock
2+
3+
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
4+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
5+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
6+
7+
8+
class MockExtension(ExtensionBase):
9+
"""A mock ExtensionBase subclass for testing purposes."""
10+
11+
def __init__(self, x: int):
12+
super().__init__()
13+
self._x = x
14+
15+
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
16+
def set_step_index(self, ctx: DenoiseContext):
17+
ctx.step_index = self._x
18+
19+
20+
def test_extension_base_callback_registration():
21+
"""Test that a callback can be successfully registered with an extension."""
22+
val = 5
23+
mock_extension = MockExtension(val)
24+
25+
mock_ctx = mock.MagicMock()
26+
27+
callbacks = mock_extension.get_callbacks()
28+
pre_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.PRE_DENOISE_LOOP, [])
29+
assert len(pre_denoise_loop_cbs) == 1
30+
31+
# Call the mock callback.
32+
pre_denoise_loop_cbs[0].function(mock_ctx)
33+
34+
# Confirm that the callback ran.
35+
assert mock_ctx.step_index == val
36+
37+
38+
def test_extension_base_empty_callback_type():
39+
"""Test that an empty list is returned when no callbacks are registered for a given callback type."""
40+
mock_extension = MockExtension(5)
41+
42+
# There should be no callbacks registered for POST_DENOISE_LOOP.
43+
callbacks = mock_extension.get_callbacks()
44+
45+
post_denoise_loop_cbs = callbacks.get(ExtensionCallbackType.POST_DENOISE_LOOP, [])
46+
assert len(post_denoise_loop_cbs) == 0
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
from unittest import mock
2+
3+
import pytest
4+
5+
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
6+
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
7+
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
8+
from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
9+
10+
11+
class MockExtension(ExtensionBase):
12+
"""A mock ExtensionBase subclass for testing purposes."""
13+
14+
def __init__(self, x: int):
15+
super().__init__()
16+
self._x = x
17+
18+
# Note that order is not specified. It should default to 0.
19+
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
20+
def set_step_index(self, ctx: DenoiseContext):
21+
ctx.step_index = self._x
22+
23+
24+
class MockExtensionLate(ExtensionBase):
25+
"""A mock ExtensionBase subclass with a high order value on its PRE_DENOISE_LOOP callback."""
26+
27+
def __init__(self, x: int):
28+
super().__init__()
29+
self._x = x
30+
31+
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
32+
def set_step_index(self, ctx: DenoiseContext):
33+
ctx.step_index = self._x
34+
35+
36+
def test_extension_manager_run_callback():
37+
"""Test that run_callback runs all callbacks for the given callback type."""
38+
39+
em = ExtensionsManager()
40+
mock_extension_1 = MockExtension(1)
41+
em.add_extension(mock_extension_1)
42+
43+
mock_ctx = mock.MagicMock()
44+
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
45+
46+
assert mock_ctx.step_index == 1
47+
48+
49+
def test_extension_manager_run_callback_no_callbacks():
50+
"""Test that run_callback does not raise an error when there are no callbacks for the given callback type."""
51+
em = ExtensionsManager()
52+
mock_ctx = mock.MagicMock()
53+
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
54+
55+
56+
@pytest.mark.parametrize(
57+
["extension_1", "extension_2"],
58+
# Regardless of initialization order, we expect MockExtensionLate to run last.
59+
[(MockExtension(1), MockExtensionLate(2)), (MockExtensionLate(2), MockExtension(1))],
60+
)
61+
def test_extension_manager_order_callbacks(extension_1: ExtensionBase, extension_2: ExtensionBase):
62+
"""Test that run_callback runs callbacks in the correct order."""
63+
em = ExtensionsManager()
64+
em.add_extension(extension_1)
65+
em.add_extension(extension_2)
66+
67+
mock_ctx = mock.MagicMock()
68+
em.run_callback(ExtensionCallbackType.PRE_DENOISE_LOOP, mock_ctx)
69+
70+
assert mock_ctx.step_index == 2
71+
72+
73+
class MockExtensionStableSort(ExtensionBase):
74+
"""A mock extension with three PRE_DENOISE_LOOP callbacks, each with a different order value."""
75+
76+
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=-1000)
77+
def early(self, ctx: DenoiseContext):
78+
pass
79+
80+
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP)
81+
def middle(self, ctx: DenoiseContext):
82+
pass
83+
84+
@callback(ExtensionCallbackType.PRE_DENOISE_LOOP, order=1000)
85+
def late(self, ctx: DenoiseContext):
86+
pass
87+
88+
89+
def test_extension_manager_stable_sort():
90+
"""Test that when two callbacks have the same 'order' value, they are sorted based on the order they were added to
91+
the ExtensionsManager."""
92+
93+
em = ExtensionsManager()
94+
95+
mock_extension_1 = MockExtensionStableSort()
96+
mock_extension_2 = MockExtensionStableSort()
97+
98+
em.add_extension(mock_extension_1)
99+
em.add_extension(mock_extension_2)
100+
101+
expected_order = [
102+
mock_extension_1.early,
103+
mock_extension_2.early,
104+
mock_extension_1.middle,
105+
mock_extension_2.middle,
106+
mock_extension_1.late,
107+
mock_extension_2.late,
108+
]
109+
110+
# It's not ideal that we are accessing a private attribute here, but this was the most direct way to assert the
111+
# desired behaviour.
112+
assert [cb.function for cb in em._ordered_callbacks[ExtensionCallbackType.PRE_DENOISE_LOOP]] == expected_order

0 commit comments

Comments
 (0)