Commit 384a1a6

Merge branch 'main' into z-image_metadata_node

2 parents: a05a626 + 0021404

7 files changed: +247 −5

invokeai/app/services/config/config_default.py

Lines changed: 3 additions & 1 deletion
@@ -85,6 +85,7 @@ class InvokeAIAppConfig(BaseSettings):
         max_cache_ram_gb: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
         max_cache_vram_gb: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
         log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
+        model_cache_keep_alive_min: How long to keep models in cache after last use, in minutes. A value of 0 (the default) means models are kept in cache indefinitely. If no model generations occur within the timeout period, the model cache is cleared using the same logic as the 'Clear Model Cache' button.
         device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
         enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
         keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.

@@ -165,9 +166,10 @@ class InvokeAIAppConfig(BaseSettings):
     max_cache_ram_gb: Optional[float] = Field(default=None, gt=0, description="The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.")
     max_cache_vram_gb: Optional[float] = Field(default=None, ge=0, description="The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.")
     log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
+    model_cache_keep_alive_min: float = Field(default=0, ge=0, description="How long to keep models in cache after last use, in minutes. A value of 0 (the default) means models are kept in cache indefinitely. If no model generations occur within the timeout period, the model cache is cleared using the same logic as the 'Clear Model Cache' button.")
     device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
     enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
-    keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
+    keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
     # Deprecated CACHE configs
     ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
     vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
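For reference, the new setting validates like any other bounded pydantic Field. Below is a minimal sketch of the validation semantics only, using a plain BaseModel and a hypothetical CacheSettingsSketch class rather than the app's InvokeAIAppConfig:

    from pydantic import BaseModel, ConfigDict, Field, ValidationError

    class CacheSettingsSketch(BaseModel):
        # Allow a field named "model_*" without pydantic's protected-namespace warning.
        model_config = ConfigDict(protected_namespaces=())

        # Mirrors the new field: default 0 disables the timeout, ge=0 rejects negatives.
        model_cache_keep_alive_min: float = Field(default=0, ge=0)

    print(CacheSettingsSketch().model_cache_keep_alive_min)  # 0.0 -> keep models indefinitely
    print(CacheSettingsSketch(model_cache_keep_alive_min=30).model_cache_keep_alive_min)  # 30.0

    try:
        CacheSettingsSketch(model_cache_keep_alive_min=-5)
    except ValidationError:
        print("negative values are rejected by the ge=0 constraint")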

invokeai/app/services/model_manager/model_manager_default.py

Lines changed: 7 additions & 0 deletions
@@ -60,6 +60,10 @@ def start(self, invoker: Invoker) -> None:
             service.start(invoker)

     def stop(self, invoker: Invoker) -> None:
+        # Shutdown the model cache to cancel any pending timers
+        if hasattr(self._load, "ram_cache"):
+            self._load.ram_cache.shutdown()
+
         for service in [self._store, self._install, self._load]:
             if hasattr(service, "stop"):
                 service.stop(invoker)

@@ -88,7 +92,10 @@ def build_model_manager(
             max_ram_cache_size_gb=app_config.max_cache_ram_gb,
             max_vram_cache_size_gb=app_config.max_cache_vram_gb,
             execution_device=execution_device or TorchDevice.choose_torch_device(),
+            storage_device="cpu",
+            log_memory_usage=app_config.log_memory_usage,
             logger=logger,
+            keep_alive_minutes=app_config.model_cache_keep_alive_min,
         )
         loader = ModelLoadService(
             app_config=app_config,
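The stop() hook above probes for the cache with hasattr() rather than assuming every loader exposes one. A minimal sketch of this duck-typed shutdown pattern, with hypothetical FakeCache/FakeLoader stand-ins that are not from the InvokeAI codebase:

    class FakeCache:
        def shutdown(self) -> None:
            print("cache: pending timers cancelled")

    class FakeLoader:
        ram_cache = FakeCache()

    def stop(load_service: object) -> None:
        # Only cache-backed loaders expose a ram_cache; guard before calling.
        if hasattr(load_service, "ram_cache"):
            load_service.ram_cache.shutdown()

    stop(FakeLoader())  # cache: pending timers cancelled
    stop(object())      # no-op: no ram_cache attribute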

invokeai/backend/model_manager/load/load_default.py

Lines changed: 0 additions & 1 deletion
@@ -75,7 +75,6 @@ def _load_and_cache(self, config: AnyModelConfig, submodel_type: Optional[SubMod

         config.path = str(self._get_model_path(config))
         self._ram_cache.make_room(self.get_size_fs(config, Path(config.path), submodel_type))
-        self._logger.info(f"Loading model '{stats_name}' into RAM cache..., config={config}")
         loaded_model = self._load_model(config, submodel_type)

         self._ram_cache.put(

invokeai/backend/model_manager/load/model_cache/model_cache.py

Lines changed: 104 additions & 2 deletions
@@ -55,6 +55,21 @@ def wrapper(self, *args, **kwargs):
     return wrapper


+def record_activity(method: Callable[..., Any]) -> Callable[..., Any]:
+    """A decorator that records activity after a method completes successfully.
+
+    Note: This decorator should be applied to methods that already hold self._lock.
+    """
+
+    @wraps(method)
+    def wrapper(self, *args, **kwargs):
+        result = method(self, *args, **kwargs)
+        self._record_activity()
+        return result
+
+    return wrapper
+
+
 @dataclass
 class CacheEntrySnapshot:
     cache_key: str
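The new record_activity decorator only touches bookkeeping after the wrapped method returns without raising. A self-contained demo of the same pattern, with a hypothetical Bump class standing in for the cache:

    from functools import wraps
    from typing import Any, Callable

    def record_activity(method: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            result = method(self, *args, **kwargs)
            self._record_activity()  # runs only on success; an exception skips it
            return result
        return wrapper

    class Bump:
        def __init__(self) -> None:
            self.activity_count = 0

        def _record_activity(self) -> None:
            self.activity_count += 1

        @record_activity
        def get(self, key: str) -> str:
            return f"value-for-{key}"

    b = Bump()
    b.get("model-a")
    print(b.activity_count)  # 1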
@@ -132,6 +147,7 @@ def __init__(
         storage_device: torch.device | str = "cpu",
         log_memory_usage: bool = False,
         logger: Optional[Logger] = None,
+        keep_alive_minutes: float = 0,
     ):
         """Initialize the model RAM cache.

@@ -151,6 +167,7 @@ def __init__(
         snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
         behaviour.
         :param logger: InvokeAILogger to use (otherwise creates one)
+        :param keep_alive_minutes: How long to keep models in cache after last use (in minutes). 0 means keep indefinitely.
         """
         self._enable_partial_loading = enable_partial_loading
         self._keep_ram_copy_of_weights = keep_ram_copy_of_weights

@@ -182,6 +199,12 @@ def __init__(
         self._on_cache_miss_callbacks: set[CacheMissCallback] = set()
         self._on_cache_models_cleared_callbacks: set[CacheModelsClearedCallback] = set()

+        # Keep-alive timeout support
+        self._keep_alive_minutes = keep_alive_minutes
+        self._last_activity_time: Optional[float] = None
+        self._timeout_timer: Optional[threading.Timer] = None
+        self._shutdown_event = threading.Event()
+
     def on_cache_hit(self, cb: CacheHitCallback) -> Callable[[], None]:
         self._on_cache_hit_callbacks.add(cb)

@@ -190,7 +213,7 @@ def unsubscribe() -> None:

         return unsubscribe

-    def on_cache_miss(self, cb: CacheHitCallback) -> Callable[[], None]:
+    def on_cache_miss(self, cb: CacheMissCallback) -> Callable[[], None]:
         self._on_cache_miss_callbacks.add(cb)

         def unsubscribe() -> None:
@@ -218,7 +241,78 @@ def stats(self, stats: CacheStats) -> None:
         """Set the CacheStats object for collecting cache statistics."""
         self._stats = stats

+    def _record_activity(self) -> None:
+        """Record model activity and reset the timeout timer if configured.
+
+        Note: This method should only be called when self._lock is already held.
+        """
+        if self._keep_alive_minutes <= 0:
+            return
+
+        self._last_activity_time = time.time()
+
+        # Cancel any existing timer
+        if self._timeout_timer is not None:
+            self._timeout_timer.cancel()
+
+        # Start a new timer
+        timeout_seconds = self._keep_alive_minutes * 60
+        self._timeout_timer = threading.Timer(timeout_seconds, self._on_timeout)
+        # Set as daemon so it doesn't prevent application shutdown
+        self._timeout_timer.daemon = True
+        self._timeout_timer.start()
+        self._logger.debug(f"Model cache activity recorded. Timeout set to {self._keep_alive_minutes} minutes.")
+
     @synchronized
+    @record_activity
+    def _on_timeout(self) -> None:
+        """Called when the keep-alive timeout expires. Clears the model cache."""
+        if self._shutdown_event.is_set():
+            return
+
+        # Double-check if there has been activity since the timer was set
+        # This handles the race condition where activity occurred just before the timer fired
+        if self._last_activity_time is not None and self._keep_alive_minutes > 0:
+            elapsed_minutes = (time.time() - self._last_activity_time) / 60
+            if elapsed_minutes < self._keep_alive_minutes:
+                # Activity occurred, don't clear cache
+                self._logger.debug(
+                    f"Model cache timeout fired but activity detected {elapsed_minutes:.2f} minutes ago. "
+                    f"Skipping cache clear."
+                )
+                return
+
+        # Check if there are any unlocked models that can be cleared
+        unlocked_models = [key for key, entry in self._cached_models.items() if not entry.is_locked]
+
+        if len(unlocked_models) > 0:
+            self._logger.info(
+                f"Model cache keep-alive timeout of {self._keep_alive_minutes} minutes expired. "
+                f"Clearing {len(unlocked_models)} unlocked model(s) from cache."
+            )
+            # Clear the cache by requesting a very large amount of space.
+            # This is the same logic used by the "Clear Model Cache" button.
+            # Using 1000 GB ensures all unlocked models are removed.
+            self._make_room_internal(1000 * GB)
+        elif len(self._cached_models) > 0:
+            # All models are locked, don't log at info level
+            self._logger.debug(
+                f"Model cache timeout fired but all {len(self._cached_models)} model(s) are locked. "
+                f"Skipping cache clear."
+            )
+        else:
+            self._logger.debug("Model cache timeout fired but cache is already empty.")
+
+    @synchronized
+    def shutdown(self) -> None:
+        """Shutdown the model cache, cancelling any pending timers."""
+        self._shutdown_event.set()
+        if self._timeout_timer is not None:
+            self._timeout_timer.cancel()
+            self._timeout_timer = None
+
+    @synchronized
+    @record_activity
     def put(self, key: str, model: AnyModel) -> None:
         """Add a model to the cache."""
         if key in self._cached_models:
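Taken together, _record_activity() and _on_timeout() implement a debounce: each cache operation cancels the pending threading.Timer and arms a fresh one, so the clear only fires after a full quiet period. A self-contained sketch of that mechanism, using seconds instead of minutes and a print in place of the cache clear:

    import threading
    import time
    from typing import Callable, Optional

    class KeepAlive:
        def __init__(self, timeout_s: float, on_timeout: Callable[[], None]) -> None:
            self._timeout_s = timeout_s
            self._on_timeout = on_timeout
            self._timer: Optional[threading.Timer] = None
            self._lock = threading.Lock()

        def record_activity(self) -> None:
            with self._lock:
                if self._timer is not None:
                    self._timer.cancel()  # reset the countdown on every activity
                self._timer = threading.Timer(self._timeout_s, self._on_timeout)
                self._timer.daemon = True  # don't block interpreter exit
                self._timer.start()

    ka = KeepAlive(0.2, lambda: print("idle: clearing cache"))
    ka.record_activity()
    time.sleep(0.1)
    ka.record_activity()  # arrives before the timeout, so the timer is re-armed
    time.sleep(0.3)       # quiet period elapses and the callback fires once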
@@ -228,7 +322,7 @@ def put(self, key: str, model: AnyModel) -> None:
             return

         size = calc_model_size_by_data(self._logger, model)
-        self.make_room(size)
+        self._make_room_internal(size)

         # Inject custom modules into the model.
         if isinstance(model, torch.nn.Module):

@@ -272,6 +366,7 @@ def _get_cache_snapshot(self) -> dict[str, CacheEntrySnapshot]:
         return overview

     @synchronized
+    @record_activity
     def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         """Retrieve a model from the cache.

@@ -309,9 +404,11 @@ def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord:
         self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})")
         for cb in self._on_cache_hit_callbacks:
             cb(model_key=key, cache_snapshot=self._get_cache_snapshot())
+
         return cache_entry

     @synchronized
+    @record_activity
     def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> None:
         """Lock a model for use and move it into VRAM."""
         if cache_entry.key not in self._cached_models:

@@ -348,6 +445,7 @@ def lock(self, cache_entry: CacheRecord, working_mem_bytes: Optional[int]) -> No
         self._log_cache_state()

     @synchronized
+    @record_activity
     def unlock(self, cache_entry: CacheRecord) -> None:
         """Unlock a model."""
         if cache_entry.key not in self._cached_models:

@@ -691,6 +789,10 @@ def make_room(self, bytes_needed: int) -> None:
        external references to the model, there's nothing that the cache can do about it, and those models will not be
        garbage-collected.
        """
+        self._make_room_internal(bytes_needed)
+
+    def _make_room_internal(self, bytes_needed: int) -> None:
+        """Internal implementation of make_room(). Assumes the lock is already held."""
         self._logger.debug(f"Making room for {bytes_needed / MB:.2f}MB of RAM.")
         self._log_cache_state(title="Before dropping models:")
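The make_room()/_make_room_internal() split exists because the cache's lock is not reentrant: public entry points take the lock, while callers that already hold it (such as _on_timeout(), which is wrapped in @synchronized) must call the internal variant directly. A minimal sketch of the pattern with a hypothetical Cache class:

    import threading

    class Cache:
        def __init__(self) -> None:
            self._lock = threading.Lock()

        def make_room(self, bytes_needed: int) -> None:
            with self._lock:  # public entry point: acquire the lock
                self._make_room_internal(bytes_needed)

        def _make_room_internal(self, bytes_needed: int) -> None:
            # Assumes self._lock is already held by the caller.
            print(f"evicting until {bytes_needed} bytes are free")

        def on_timeout(self) -> None:
            with self._lock:
                # Already inside the lock, so use the internal variant;
                # calling self.make_room() here would deadlock on the
                # non-reentrant Lock.
                self._make_room_internal(10**12)

    Cache().on_timeout()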

invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py

Lines changed: 0 additions & 1 deletion
@@ -140,7 +140,6 @@ def _load_from_singlefile(
         # Some weights of the model checkpoint were not used when initializing CLIPTextModelWithProjection:
         # ['text_model.embeddings.position_ids']

-        self._logger.info(f"Loading model from single file at {config.path} using {load_class.__name__}")
         with SilenceWarnings():
             pipeline = load_class.from_single_file(config.path, torch_dtype=self._torch_dtype)
invokeai/frontend/web/src/services/api/schema.ts

Lines changed: 7 additions & 0 deletions
@@ -13036,6 +13036,7 @@ export type components = {
      * max_cache_ram_gb: The maximum amount of CPU RAM to use for model caching in GB. If unset, the limit will be configured based on the available RAM. In most cases, it is recommended to leave this unset.
      * max_cache_vram_gb: The amount of VRAM to use for model caching in GB. If unset, the limit will be configured based on the available VRAM and the device_working_mem_gb. In most cases, it is recommended to leave this unset.
      * log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
+     * model_cache_keep_alive_min: How long to keep models in cache after last use, in minutes. A value of 0 (the default) means models are kept in cache indefinitely. If no model generations occur within the timeout period, the model cache is cleared using the same logic as the 'Clear Model Cache' button.
      * device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
      * enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
      * keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.

@@ -13279,6 +13280,12 @@ export type components = {
      * @default false
      */
     log_memory_usage?: boolean;
+    /**
+     * Model Cache Keep Alive Min
+     * @description How long to keep models in cache after last use, in minutes. A value of 0 (the default) means models are kept in cache indefinitely. If no model generations occur within the timeout period, the model cache is cleared using the same logic as the 'Clear Model Cache' button.
+     * @default 0
+     */
+    model_cache_keep_alive_min?: number;
     /**
      * Device Working Mem Gb
      * @description The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.

0 commit comments