Skip to content

Commit d827821

Browse files
fix ModelStorage save/load
1 parent 088e1ae commit d827821

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

dialogue2graph/pipelines/model_storage.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def load(self, path: Path):
9898
try:
9999
with open(path, "r") as f:
100100
loaded_storage = yaml.safe_load(f)
101-
for key, config in loaded_storage.items():
102-
self.add(key=key, config=config, model_type=config.pop("model_type"))
101+
for key in loaded_storage:
102+
self.add(key=key, config=loaded_storage[key].get("config"), model_type=loaded_storage[key].get("model_type"))
103103
logger.debug(f"Loaded model configuration for '{key}'")
104104
logger.info(f"Successfully loaded {len(loaded_storage)} models from {path}")
105105
except Exception as e:
@@ -148,7 +148,11 @@ def save(self, path: str):
148148
logger.debug(f"Attempting to save model storage to {path}")
149149
try:
150150
with open(path, "w") as f:
151-
storage_dump = {k: v.config for k, v in self.storage.items()}
151+
storage_dump = {}
152+
for model_key in self.storage:
153+
storage_dump[model_key] = {}
154+
storage_dump[model_key]["config"] = self.storage[model_key].config
155+
storage_dump[model_key]["model_type"] = self.storage[model_key].model_type
152156
yaml.dump(storage_dump, f)
153157
logger.info(f"Saved {len(self.storage)} models to {path}")
154158
except Exception as e:
Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
d2g_llm_filling_llm:v1:
2-
name: o3-mini
3-
temperature: 1
2+
config:
3+
name: o3-mini
4+
temperature: 1
5+
model_type: llm
46
d2g_llm_grouping_llm:v1:
5-
name: gpt-4o-latest
6-
temperature: 0
7+
config:
8+
name: gpt-4o-latest
9+
temperature: 0
10+
model_type: llm
711
d2g_llm_sim_model:v1:
8-
model_kwargs:
9-
device: cpu
10-
model_name: cointegrated/LaBSE-en-ru
12+
config:
13+
model_kwargs:
14+
device: cpu
15+
model_name: cointegrated/LaBSE-en-ru
16+
model_type: emb
1117
my_model:
12-
name: gpt-3.5-turbo
18+
config:
19+
name: gpt-3.5-turbo
20+
model_type: llm

0 commit comments

Comments
 (0)