Skip to content

Commit 661b305

Browse files
psychedeliciousMillu
authored andcommitted
feat(nodes): add enable, disable, status to invocation cache
- New routes to clear, enable, disable and get the status of the cache - Status includes hits, misses, size, max size, enabled - Add client cache queries and mutations, abstracted into hooks - Add invocation cache status area (next to queue status) w/ buttons
1 parent 20f7e44 commit 661b305

22 files changed

+683
-130
lines changed

invokeai/app/api/routers/app_info.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pydantic import BaseModel, Field
88

99
from invokeai.app.invocations.upscale import ESRGAN_MODELS
10+
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
1011
from invokeai.backend.image_util.invisible_watermark import InvisibleWatermark
1112
from invokeai.backend.image_util.patchmatch import PatchMatch
1213
from invokeai.backend.image_util.safety_checker import SafetyChecker
@@ -113,3 +114,33 @@ async def set_log_level(
113114
async def clear_invocation_cache() -> None:
114115
"""Clears the invocation cache"""
115116
ApiDependencies.invoker.services.invocation_cache.clear()
117+
118+
119+
@app_router.put(
120+
"/invocation_cache/enable",
121+
operation_id="enable_invocation_cache",
122+
responses={200: {"description": "The operation was successful"}},
123+
)
124+
async def enable_invocation_cache() -> None:
125+
"""Clears the invocation cache"""
126+
ApiDependencies.invoker.services.invocation_cache.enable()
127+
128+
129+
@app_router.put(
130+
"/invocation_cache/disable",
131+
operation_id="disable_invocation_cache",
132+
responses={200: {"description": "The operation was successful"}},
133+
)
134+
async def disable_invocation_cache() -> None:
135+
"""Clears the invocation cache"""
136+
ApiDependencies.invoker.services.invocation_cache.disable()
137+
138+
139+
@app_router.get(
140+
"/invocation_cache/status",
141+
operation_id="get_invocation_cache_status",
142+
responses={200: {"model": InvocationCacheStatus}},
143+
)
144+
async def get_invocation_cache_status() -> InvocationCacheStatus:
145+
"""Clears the invocation cache"""
146+
return ApiDependencies.invoker.services.invocation_cache.get_status()

invokeai/app/services/invocation_cache/invocation_cache_base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, Union
33

44
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
5+
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
56

67

78
class InvocationCacheBase(ABC):
@@ -32,7 +33,7 @@ def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) ->
3233

3334
@abstractmethod
3435
def delete(self, key: Union[int, str]) -> None:
35-
"""Deleteds an invocation output from the cache"""
36+
"""Deletes an invocation output from the cache"""
3637
pass
3738

3839
@abstractmethod
@@ -44,3 +45,18 @@ def clear(self) -> None:
4445
def create_key(self, invocation: BaseInvocation) -> int:
4546
"""Gets the key for the invocation's cache item"""
4647
pass
48+
49+
@abstractmethod
50+
def disable(self) -> None:
51+
"""Disables the cache, overriding the max cache size"""
52+
pass
53+
54+
@abstractmethod
55+
def enable(self) -> None:
56+
"""Enables the cache, letting the the max cache size take effect"""
57+
pass
58+
59+
@abstractmethod
60+
def get_status(self) -> InvocationCacheStatus:
61+
"""Returns the status of the cache"""
62+
pass
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from pydantic import BaseModel, Field
2+
3+
4+
class InvocationCacheStatus(BaseModel):
5+
size: int = Field(description="The current size of the invocation cache")
6+
hits: int = Field(description="The number of cache hits")
7+
misses: int = Field(description="The number of cache misses")
8+
enabled: bool = Field(description="Whether the invocation cache is enabled")
9+
max_size: int = Field(description="The maximum size of the invocation cache")

invokeai/app/services/invocation_cache/invocation_cache_memory.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,25 @@
33

44
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
55
from invokeai.app.services.invocation_cache.invocation_cache_base import InvocationCacheBase
6+
from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus
67
from invokeai.app.services.invoker import Invoker
78

89

910
class MemoryInvocationCache(InvocationCacheBase):
1011
__cache: dict[Union[int, str], tuple[BaseInvocationOutput, str]]
1112
__max_cache_size: int
13+
__disabled: bool
14+
__hits: int
15+
__misses: int
1216
__cache_ids: Queue
1317
__invoker: Invoker
1418

1519
def __init__(self, max_cache_size: int = 0) -> None:
1620
self.__cache = dict()
1721
self.__max_cache_size = max_cache_size
22+
self.__disabled = False
23+
self.__hits = 0
24+
self.__misses = 0
1825
self.__cache_ids = Queue()
1926

2027
def start(self, invoker: Invoker) -> None:
@@ -25,15 +32,17 @@ def start(self, invoker: Invoker) -> None:
2532
self.__invoker.services.latents.on_deleted(self._delete_by_match)
2633

2734
def get(self, key: Union[int, str]) -> Optional[BaseInvocationOutput]:
28-
if self.__max_cache_size == 0:
35+
if self.__max_cache_size == 0 or self.__disabled:
2936
return
3037

3138
item = self.__cache.get(key, None)
3239
if item is not None:
40+
self.__hits += 1
3341
return item[0]
42+
self.__misses += 1
3443

3544
def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) -> None:
36-
if self.__max_cache_size == 0:
45+
if self.__max_cache_size == 0 or self.__disabled:
3746
return
3847

3948
if key not in self.__cache:
@@ -47,24 +56,41 @@ def save(self, key: Union[int, str], invocation_output: BaseInvocationOutput) ->
4756
pass
4857

4958
def delete(self, key: Union[int, str]) -> None:
50-
if self.__max_cache_size == 0:
59+
if self.__max_cache_size == 0 or self.__disabled:
5160
return
5261

5362
if key in self.__cache:
5463
del self.__cache[key]
5564

5665
def clear(self, *args, **kwargs) -> None:
57-
if self.__max_cache_size == 0:
66+
if self.__max_cache_size == 0 or self.__disabled:
5867
return
5968

6069
self.__cache.clear()
6170
self.__cache_ids = Queue()
71+
self.__misses = 0
72+
self.__hits = 0
6273

6374
def create_key(self, invocation: BaseInvocation) -> int:
6475
return hash(invocation.json(exclude={"id"}))
6576

77+
def disable(self) -> None:
78+
self.__disabled = True
79+
80+
def enable(self) -> None:
81+
self.__disabled = False
82+
83+
def get_status(self) -> InvocationCacheStatus:
84+
return InvocationCacheStatus(
85+
hits=self.__hits,
86+
misses=self.__misses,
87+
enabled=not self.__disabled,
88+
size=len(self.__cache),
89+
max_size=self.__max_cache_size,
90+
)
91+
6692
def _delete_by_match(self, to_match: str) -> None:
67-
if self.__max_cache_size == 0:
93+
if self.__max_cache_size == 0 or self.__disabled:
6894
return
6995

7096
keys_to_delete = set()

invokeai/app/services/session_queue/session_queue_common.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,15 @@ class SessionQueueItemWithoutGraph(BaseModel):
162162
session_id: str = Field(
163163
description="The ID of the session associated with this queue item. The session doesn't exist in graph_executions until the queue item is executed."
164164
)
165-
field_values: Optional[list[NodeFieldValue]] = Field(
166-
default=None, description="The field values that were used for this queue item"
167-
)
168-
queue_id: str = Field(description="The id of the queue with which this item is associated")
169165
error: Optional[str] = Field(default=None, description="The error message if this queue item errored")
170166
created_at: Union[datetime.datetime, str] = Field(description="When this queue item was created")
171167
updated_at: Union[datetime.datetime, str] = Field(description="When this queue item was updated")
172168
started_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was started")
173169
completed_at: Optional[Union[datetime.datetime, str]] = Field(description="When this queue item was completed")
170+
queue_id: str = Field(description="The id of the queue with which this item is associated")
171+
field_values: Optional[list[NodeFieldValue]] = Field(
172+
default=None, description="The field values that were used for this queue item"
173+
)
174174

175175
@classmethod
176176
def from_dict(cls, queue_item_dict: dict) -> "SessionQueueItemDTO":

invokeai/frontend/web/public/locales/en.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,22 @@
264264
"graphQueued": "Graph queued",
265265
"graphFailedToQueue": "Failed to queue graph"
266266
},
267+
"invocationCache": {
268+
"invocationCache": "Invocation Cache",
269+
"cacheSize": "Cache Size",
270+
"maxCacheSize": "Max Cache Size",
271+
"hits": "Cache Hits",
272+
"misses": "Cache Misses",
273+
"clear": "Clear",
274+
"clearSucceeded": "Invocation Cache Cleared",
275+
"clearFailed": "Problem Clearing Invocation Cache",
276+
"enable": "Enable",
277+
"enableSucceeded": "Invocation Cache Enabled",
278+
"enableFailed": "Problem Enabling Invocation Cache",
279+
"disable": "Disable",
280+
"disableSucceeded": "Invocation Cache Disabled",
281+
"disableFailed": "Problem Disabling Invocation Cache"
282+
},
267283
"gallery": {
268284
"allImagesLoaded": "All Images Loaded",
269285
"assets": "Assets",

invokeai/frontend/web/src/app/types/invokeai.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ export type AppFeature =
2121
| 'multiselect'
2222
| 'pauseQueue'
2323
| 'resumeQueue'
24-
| 'prependQueue';
24+
| 'prependQueue'
25+
| 'invocationCache';
2526

2627
/**
2728
* A disable-able Stable Diffusion feature
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import IAIButton from 'common/components/IAIButton';
2+
import { memo } from 'react';
3+
import { useTranslation } from 'react-i18next';
4+
import { useClearInvocationCache } from '../hooks/useClearInvocationCache';
5+
6+
const ClearInvocationCacheButton = () => {
7+
const { t } = useTranslation();
8+
const { clearInvocationCache, isDisabled, isLoading } =
9+
useClearInvocationCache();
10+
11+
return (
12+
<IAIButton
13+
isDisabled={isDisabled}
14+
isLoading={isLoading}
15+
onClick={clearInvocationCache}
16+
>
17+
{t('invocationCache.clear')}
18+
</IAIButton>
19+
);
20+
};
21+
22+
export default memo(ClearInvocationCacheButton);
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import { ButtonGroup } from '@chakra-ui/react';
2+
import { memo } from 'react';
3+
import { useTranslation } from 'react-i18next';
4+
import { useGetInvocationCacheStatusQuery } from 'services/api/endpoints/appInfo';
5+
import ClearInvocationCacheButton from './ClearInvocationCacheButton';
6+
import ToggleInvocationCacheButton from './ToggleInvocationCacheButton';
7+
import StatusStatGroup from './common/StatusStatGroup';
8+
import StatusStatItem from './common/StatusStatItem';
9+
10+
const InvocationCacheStatus = () => {
11+
const { data: cacheStatus } = useGetInvocationCacheStatusQuery(undefined, {
12+
pollingInterval: 5000,
13+
});
14+
const { t } = useTranslation();
15+
return (
16+
<StatusStatGroup>
17+
<StatusStatItem
18+
isDisabled={!cacheStatus?.enabled}
19+
label={t('invocationCache.cacheSize')}
20+
value={cacheStatus?.size ?? 0}
21+
/>
22+
<StatusStatItem
23+
isDisabled={!cacheStatus?.enabled}
24+
label={t('invocationCache.hits')}
25+
value={cacheStatus?.hits ?? 0}
26+
/>
27+
<StatusStatItem
28+
isDisabled={!cacheStatus?.enabled}
29+
label={t('invocationCache.misses')}
30+
value={cacheStatus?.misses ?? 0}
31+
/>
32+
<StatusStatItem
33+
isDisabled={!cacheStatus?.enabled}
34+
label={t('invocationCache.maxCacheSize')}
35+
value={cacheStatus?.max_size ?? 0}
36+
/>
37+
<ButtonGroup w={24} orientation="vertical" size="xs">
38+
<ClearInvocationCacheButton />
39+
<ToggleInvocationCacheButton />
40+
</ButtonGroup>
41+
</StatusStatGroup>
42+
);
43+
};
44+
45+
export default memo(InvocationCacheStatus);

invokeai/frontend/web/src/features/queue/components/QueueStatus.tsx

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,39 @@
1-
import { Stat, StatGroup, StatLabel, StatNumber } from '@chakra-ui/react';
21
import { memo } from 'react';
32
import { useTranslation } from 'react-i18next';
43
import { useGetQueueStatusQuery } from 'services/api/endpoints/queue';
4+
import StatusStatGroup from './common/StatusStatGroup';
5+
import StatusStatItem from './common/StatusStatItem';
56

67
const QueueStatus = () => {
78
const { data: queueStatus } = useGetQueueStatusQuery();
89
const { t } = useTranslation();
910
return (
10-
<StatGroup alignItems="center" justifyContent="center" w="full" h="full">
11-
<Stat w={24}>
12-
<StatLabel>{t('queue.in_progress')}</StatLabel>
13-
<StatNumber>{queueStatus?.queue.in_progress ?? 0}</StatNumber>
14-
</Stat>
15-
<Stat w={24}>
16-
<StatLabel>{t('queue.pending')}</StatLabel>
17-
<StatNumber>{queueStatus?.queue.pending ?? 0}</StatNumber>
18-
</Stat>
19-
<Stat w={24}>
20-
<StatLabel>{t('queue.completed')}</StatLabel>
21-
<StatNumber>{queueStatus?.queue.completed ?? 0}</StatNumber>
22-
</Stat>
23-
<Stat w={24}>
24-
<StatLabel>{t('queue.failed')}</StatLabel>
25-
<StatNumber>{queueStatus?.queue.failed ?? 0}</StatNumber>
26-
</Stat>
27-
<Stat w={24}>
28-
<StatLabel>{t('queue.canceled')}</StatLabel>
29-
<StatNumber>{queueStatus?.queue.canceled ?? 0}</StatNumber>
30-
</Stat>
31-
<Stat w={24}>
32-
<StatLabel>{t('queue.total')}</StatLabel>
33-
<StatNumber>{queueStatus?.queue.total}</StatNumber>
34-
</Stat>
35-
</StatGroup>
11+
<StatusStatGroup>
12+
<StatusStatItem
13+
label={t('queue.in_progress')}
14+
value={queueStatus?.queue.in_progress ?? 0}
15+
/>
16+
<StatusStatItem
17+
label={t('queue.pending')}
18+
value={queueStatus?.queue.pending ?? 0}
19+
/>
20+
<StatusStatItem
21+
label={t('queue.completed')}
22+
value={queueStatus?.queue.completed ?? 0}
23+
/>
24+
<StatusStatItem
25+
label={t('queue.failed')}
26+
value={queueStatus?.queue.failed ?? 0}
27+
/>
28+
<StatusStatItem
29+
label={t('queue.canceled')}
30+
value={queueStatus?.queue.canceled ?? 0}
31+
/>
32+
<StatusStatItem
33+
label={t('queue.total')}
34+
value={queueStatus?.queue.total ?? 0}
35+
/>
36+
</StatusStatGroup>
3637
);
3738
};
3839

0 commit comments

Comments
 (0)