@@ -30,7 +30,8 @@ def __init__(self, path: Path, hash_algo: HASHING_ALGORITHMS = "blake3_single"):
30
30
self .hash_algo = hash_algo
31
31
# Having a cache helps users of ModelOnDisk (i.e. configs) to save state
32
32
# This prevents redundant computations during matching and parsing
33
- self .cache = {"_CACHED_STATE_DICTS" : {}}
33
+ self ._state_dict_cache : dict [Path , Any ] = {}
34
+ self ._metadata_cache : dict [Path , Any ] = {}
34
35
35
36
def hash (self ) -> str :
36
37
return ModelHash (algorithm = self .hash_algo ).hash (self .path )
@@ -47,13 +48,18 @@ def weight_files(self) -> set[Path]:
47
48
return {f for f in self .path .rglob ("*" ) if f .suffix in extensions }
48
49
49
50
def metadata (self , path : Optional [Path ] = None ) -> dict [str , str ]:
51
+ path = path or self .path
52
+ if path in self ._metadata_cache :
53
+ return self ._metadata_cache [path ]
50
54
try :
51
55
with safe_open (self .path , framework = "pt" , device = "cpu" ) as f :
52
56
metadata = f .metadata ()
53
57
assert isinstance (metadata , dict )
54
- return metadata
55
58
except Exception :
56
- return {}
59
+ metadata = {}
60
+
61
+ self ._metadata_cache [path ] = metadata
62
+ return metadata
57
63
58
64
def repo_variant (self ) -> Optional [ModelRepoVariant ]:
59
65
if self .path .is_file ():
@@ -73,10 +79,8 @@ def repo_variant(self) -> Optional[ModelRepoVariant]:
73
79
return ModelRepoVariant .Default
74
80
75
81
def load_state_dict (self , path : Optional [Path ] = None ) -> StateDict :
76
- sd_cache = self .cache ["_CACHED_STATE_DICTS" ]
77
-
78
- if path in sd_cache :
79
- return sd_cache [path ]
82
+ if path in self ._state_dict_cache :
83
+ return self ._state_dict_cache [path ]
80
84
81
85
path = self .resolve_weight_file (path )
82
86
@@ -111,7 +115,7 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
111
115
raise ValueError (f"Unrecognized model extension: { path .suffix } " )
112
116
113
117
state_dict = checkpoint .get ("state_dict" , checkpoint )
114
- sd_cache [path ] = state_dict
118
+ self . _state_dict_cache [path ] = state_dict
115
119
return state_dict
116
120
117
121
def resolve_weight_file (self , path : Optional [Path ] = None ) -> Path :
0 commit comments