Commit d63348b

tidy(mm): clean up ModelOnDisk caching
Parent: 09449cf

2 files changed: +12 −11 lines

invokeai/backend/model_manager/config.py

Lines changed: 0 additions & 3 deletions
@@ -2162,9 +2162,6 @@ def get_model_discriminator_value(v: Any) -> str:
 # when AnyModelConfig is constructed dynamically using ModelConfigBase.all_config_classes
 AnyModelConfig = Annotated[
     Union[
-        # Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
-        # Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
-        # SD_1_2_XL_XLRefiner_CheckpointConfig
         Annotated[FLUX_Unquantized_CheckpointConfig, FLUX_Unquantized_CheckpointConfig.get_tag()],
         Annotated[FLUX_Quantized_BnB_NF4_CheckpointConfig, FLUX_Quantized_BnB_NF4_CheckpointConfig.get_tag()],
         Annotated[FLUX_Quantized_GGUF_CheckpointConfig, FLUX_Quantized_GGUF_CheckpointConfig.get_tag()],
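
For context, AnyModelConfig follows Pydantic's callable-discriminator pattern: each config class is wrapped in Annotated[..., <tag>] and a discriminator function (get_model_discriminator_value in the hunk header) selects the right member during validation. Below is a generic sketch of that pattern; the class names and the "format" field are illustrative stand-ins, not InvokeAI's real configs.

from typing import Annotated, Any, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class CheckpointConfig(BaseModel):
    format: str = "checkpoint"


class DiffusersConfig(BaseModel):
    format: str = "diffusers"


def get_discriminator_value(v: Any) -> str:
    # Works for raw dicts during validation and for already-built models.
    return v["format"] if isinstance(v, dict) else v.format


AnyConfig = Annotated[
    Union[
        Annotated[CheckpointConfig, Tag("checkpoint")],
        Annotated[DiffusersConfig, Tag("diffusers")],
    ],
    Discriminator(get_discriminator_value),
]

print(TypeAdapter(AnyConfig).validate_python({"format": "diffusers"}))  # -> DiffusersConfig(format='diffusers')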

invokeai/backend/model_manager/model_on_disk.py

Lines changed: 12 additions & 8 deletions
@@ -30,7 +30,8 @@ def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
         self.hash_algo = hash_algo
         # Having a cache helps users of ModelOnDisk (i.e. configs) to save state
         # This prevents redundant computations during matching and parsing
-        self.cache = {"_CACHED_STATE_DICTS": {}}
+        self._state_dict_cache: dict[Path, Any] = {}
+        self._metadata_cache: dict[Path, Any] = {}

     def hash(self) -> str:
         return ModelHash(algorithm=self.hash_algo).hash(self.path)
@@ -47,13 +48,18 @@ def weight_files(self) -> set[Path]:
         return {f for f in self.path.rglob("*") if f.suffix in extensions}

     def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
+        path = path or self.path
+        if path in self._metadata_cache:
+            return self._metadata_cache[path]
         try:
             with safe_open(self.path, framework="pt", device="cpu") as f:
                 metadata = f.metadata()
                 assert isinstance(metadata, dict)
-                return metadata
         except Exception:
-            return {}
+            metadata = {}
+
+        self._metadata_cache[path] = metadata
+        return metadata

     def repo_variant(self) -> Optional[ModelRepoVariant]:
         if self.path.is_file():
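
Assembled from the hunk above, the cache-first shape of metadata() after this commit looks like the sketch below. WeightsInspector is an illustrative stand-in so the snippet is self-contained; only the safetensors safe_open API and the lines shown in the diff are assumed. Note that the cache is keyed on the path argument while safe_open still reads self.path, as in the diff.

from pathlib import Path
from typing import Any, Optional

from safetensors import safe_open


class WeightsInspector:
    """Illustrative stand-in for the metadata-caching part of ModelOnDisk."""

    def __init__(self, path: Path):
        self.path = path
        self._metadata_cache: dict[Path, Any] = {}

    def metadata(self, path: Optional[Path] = None) -> dict[str, str]:
        path = path or self.path
        if path in self._metadata_cache:
            # Cache hit: the safetensors header is not re-read from disk.
            return self._metadata_cache[path]
        try:
            with safe_open(self.path, framework="pt", device="cpu") as f:
                metadata = f.metadata()
                assert isinstance(metadata, dict)
        except Exception:
            # Failures are cached too (as an empty dict), so a bad or
            # non-safetensors file is only probed once.
            metadata = {}

        self._metadata_cache[path] = metadata
        return metadata
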
@@ -73,10 +79,8 @@ def repo_variant(self) -> Optional[ModelRepoVariant]:
         return ModelRepoVariant.Default

     def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
-        sd_cache = self.cache["_CACHED_STATE_DICTS"]
-
-        if path in sd_cache:
-            return sd_cache[path]
+        if path in self._state_dict_cache:
+            return self._state_dict_cache[path]

         path = self.resolve_weight_file(path)

@@ -111,7 +115,7 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
            raise ValueError(f"Unrecognized model extension: {path.suffix}")

         state_dict = checkpoint.get("state_dict", checkpoint)
-        sd_cache[path] = state_dict
+        self._state_dict_cache[path] = state_dict
         return state_dict

     def resolve_weight_file(self, path: Optional[Path] = None) -> Path:
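
A hypothetical caller-side sketch of what the two caches buy: repeated lookups on the same ModelOnDisk return the in-memory result instead of re-reading the file. The path is illustrative, and the identity check on the state dict assumes resolve_weight_file returns the path it was given when it already points at a weights file.

from pathlib import Path

from invokeai.backend.model_manager.model_on_disk import ModelOnDisk

weights = Path("/models/flux1-dev.safetensors")  # hypothetical path
mod = ModelOnDisk(weights)

meta_first = mod.metadata()   # reads the safetensors header and stores it in _metadata_cache
meta_again = mod.metadata()   # served from _metadata_cache; the file is not opened again
assert meta_first is meta_again

sd_first = mod.load_state_dict(weights)   # loads the checkpoint, cached in _state_dict_cache
sd_again = mod.load_state_dict(weights)   # cache hit under the same key
assert sd_first is sd_again

As before this commit, a call that omits path checks the cache with None as the key but stores under the resolved path, so the state-dict cache only pays off when callers pass the same explicit path.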
