Skip to content

Commit d6a45e1

Browse files
committed
fix dumper
1 parent bf1d833 commit d6a45e1

File tree

2 files changed

+30
-10
lines changed

2 files changed

+30
-10
lines changed

autointent/_dump_tools.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
8686
attrs: dict[str, ModuleAttributes] = vars(obj)
8787
simple_attrs = {}
8888
arrays: dict[str, npt.NDArray[Any]] = {}
89+
containers = {}
8990

9091
Dumper.make_subdirectories(path, exists_ok)
9192

@@ -96,6 +97,8 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
9697
val.dump(path / Dumper.tags / key)
9798
elif isinstance(val, ModuleSimpleAttributes):
9899
simple_attrs[key] = val
100+
elif isinstance(val, dict):
101+
containers[key] = val
99102
elif isinstance(val, np.ndarray):
100103
arrays[key] = val
101104
elif isinstance(val, Embedder):
@@ -154,6 +157,9 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
154157
"module": val.__class__.__module__,
155158
"name": val.__class__.__name__,
156159
}
160+
# Save configuration if available
161+
if hasattr(val, 'get_config'):
162+
class_info['config'] = val.get_config()
157163
with (model_path / "class_info.json").open("w") as f:
158164
json.dump(class_info, f)
159165
except Exception as e:
@@ -174,6 +180,9 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
174180
with (path / Dumper.simple_attrs).open("w") as file:
175181
json.dump(simple_attrs, file, ensure_ascii=False, indent=4)
176182

183+
with (path / Dumper.containers / "containers.json").open("w") as f:
184+
json.dump(containers, f, ensure_ascii=False, indent=4)
185+
177186
np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
178187

179188
@staticmethod
@@ -275,32 +284,31 @@ def load( # noqa: C901, PLR0912, PLR0915
275284
try:
276285
with (model_dir / "class_info.json").open("r") as f:
277286
class_info = json.load(f)
278-
279287
module = __import__(class_info["module"], fromlist=[class_info["name"]])
280288
model_class = getattr(module, class_info["name"])
281-
282-
# Create model instance
283-
model = model_class()
284-
285-
# Load state dict
289+
config = class_info.get('config', {})
290+
# Initialize model with config if available
291+
model = model_class(**config)
286292
model.load_state_dict(torch.load(model_dir / "model.pt"))
287293
model.eval()
288294
torch_models[model_dir.name] = model
289-
except Exception as e: # noqa: PERF203
290-
msg = f"Error loading torch model {model_dir.name}: {e}"
291-
logger.exception(msg)
295+
except Exception as e:
296+
logger.exception(f"Error loading torch model {model_dir.name}: {e}")
292297
elif child.name == Dumper.containers:
293298
try:
294299
for container_file in child.iterdir():
300+
print(container_file)
295301
with container_file.open("r") as f:
296-
containers[container_file.stem] = json.load(f)
302+
containers = json.load(f)
297303
except Exception as e:
298304
msg = f"Error loading containers: {e}"
299305
logger.exception(msg)
300306
else:
301307
msg = f"Found unexpected child {child}"
302308
logger.error(msg)
303309

310+
print(containers)
311+
304312
obj.__dict__.update(
305313
tags
306314
| simple_attrs

autointent/modules/scoring/_cnn/textcnn.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,15 @@ def load(self, model_path: str) -> None:
7272
"""
7373
state_dict = torch.load(model_path)
7474
self.load_state_dict(state_dict)
75+
76+
def get_config(self) -> dict:
77+
return {
78+
'vocab_size': self.vocab_size.item(),
79+
'n_classes': self.n_classes.item(),
80+
'embed_dim': self.embed_dim.item(),
81+
'kernel_sizes': self.kernel_sizes.tolist(),
82+
'num_filters': self.num_filters.item(),
83+
'dropout': self.dropout_rate.item(),
84+
'padding_idx': self.padding_idx.item(),
85+
'pretrained_embs': self.pretrained_embs,
86+
}

0 commit comments

Comments
 (0)