
Commit f7a6ca7

added dump-load
1 parent 77f9751 commit f7a6ca7

2 files changed: 33 additions (+), 20 deletions (−)


autointent/_dump_tools.py

Lines changed: 8 additions & 17 deletions

@@ -102,12 +102,13 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901, PLR0912, PLR0915
             model_path = path / Dumper.torch_models / key
             model_path.mkdir(parents=True, exist_ok=True)
             try:
-                torch.save(val.state_dict(), model_path / "model.pt")
-                # Save class info for loading
+                torch.save(val._model.state_dict(), model_path / "model.pt")
+                vocab_path = path / Dumper.torch_models / "vocab.json"
+                with vocab_path.open("w") as f:
+                    json.dump(obj._vocab, f)
                 class_info = {
                     "module": val.__class__.__module__,
                     "name": val.__class__.__name__,
-                    "is_textcnn": isinstance(val, TextCNN)
                 }
                 with (model_path / "class_info.json").open("w") as f:
                     json.dump(class_info, f)
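For context, the save path in this hunk amounts to the standalone sketch below. The helper name `save_torch_model` is hypothetical; `val._model`, `obj._vocab`, and the file layout follow the diff. Note the asymmetry the diff introduces: the weights now come from the inner `val._model`, while `class_info` still records the outer `val`'s class.

import json
from pathlib import Path
from typing import Any

import torch

def save_torch_model(val: Any, vocab: dict, model_path: Path, vocab_path: Path) -> None:
    """Persist the wrapped module's weights, the vocabulary, and the class
    coordinates needed to re-import the model on load."""
    model_path.mkdir(parents=True, exist_ok=True)
    # state_dict holds parameters and registered buffers, not the class itself
    torch.save(val._model.state_dict(), model_path / "model.pt")
    with vocab_path.open("w") as f:
        json.dump(vocab, f)
    class_info = {
        "module": val.__class__.__module__,
        "name": val.__class__.__name__,
    }
    with (model_path / "class_info.json").open("w") as f:
        json.dump(class_info, f)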
@@ -254,25 +255,15 @@ def load(  # noqa: PLR0912, C901, PLR0915
         try:
             with (model_dir / "class_info.json").open("r") as f:
                 class_info = json.load(f)
+            vocab_path = path / Dumper.torch_models / "vocab.json"
+            with vocab_path.open("r") as f:
+                obj._vocab = json.load(f)

             module = __import__(class_info["module"], fromlist=[class_info["name"]])
             model_class = getattr(module, class_info["name"])

             # Create model instance
-            if class_info.get("is_textcnn"):
-                # For TextCNN, we need to get the parameters from the parent CNNScorer
-                model = model_class(
-                    vocab_size=len(obj._vocab) if hasattr(obj, "_vocab") and obj._vocab else 0,
-                    n_classes=obj._n_classes if hasattr(obj, "_n_classes") else 0,
-                    embed_dim=obj.embed_dim if hasattr(obj, "embed_dim") else 128,
-                    kernel_sizes=obj.kernel_sizes if hasattr(obj, "kernel_sizes") else [3, 4, 5],
-                    num_filters=obj.num_filters if hasattr(obj, "num_filters") else 100,
-                    dropout=obj.dropout if hasattr(obj, "dropout") else 0.1,
-                    padding_idx=obj._pad_idx if hasattr(obj, "_pad_idx") else 0
-                )
-            else:
-                # For other torch models, create with default parameters
-                model = model_class()
+            model = model_class()

             # Load state dict
             model.load_state_dict(torch.load(model_dir / "model.pt"))
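The re-import pattern used here relies on a detail of `__import__`: with a non-empty `fromlist`, it returns the leaf module (e.g. `textcnn`) rather than the top-level package, so `getattr` can fetch the class directly. A minimal sketch with a hypothetical helper name:

import json
from pathlib import Path

import torch

def load_torch_model(model_dir: Path) -> torch.nn.Module:
    """Re-create a model from class_info.json plus model.pt."""
    with (model_dir / "class_info.json").open("r") as f:
        class_info = json.load(f)
    # fromlist=[...] makes __import__ return the defining module itself
    module = __import__(class_info["module"], fromlist=[class_info["name"]])
    model_class = getattr(module, class_info["name"])
    model = model_class()  # assumes every dumped class has a zero-arg constructor
    model.load_state_dict(torch.load(model_dir / "model.pt"))
    return model

Note that `load_state_dict` raises on tensor-shape mismatches, so the zero-arg instance must already have tensors shaped like the checkpoint's.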

autointent/modules/scoring/_cnn/textcnn.py

Lines changed: 25 additions & 3 deletions

@@ -10,8 +10,8 @@ class TextCNN(nn.Module):

     def __init__(
         self,
-        vocab_size: int,
-        n_classes: int,
+        vocab_size: int = 0,
+        n_classes: int = 0,
         embed_dim: int = 128,
         kernel_sizes: list[int] = [3, 4, 5],  # noqa: B006
         num_filters: int = 100,
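The new defaults exist so the generic loader's bare `model_class()` call can construct a `TextCNN` without knowing its hyperparameters. Worth noting as plain PyTorch behavior, not anything specific to this commit: `load_state_dict` rejects tensors whose shapes differ from the checkpoint's, so an instance still needs checkpoint-compatible sizes before loading. A sketch with hypothetical sizes:

import torch

from autointent.modules.scoring._cnn.textcnn import TextCNN

# vocab_size/n_classes here stand in for whatever the checkpoint was
# trained with; mismatched shapes make load_state_dict raise.
model = TextCNN(vocab_size=20_000, n_classes=4)
state = torch.load("model.pt", map_location="cpu")
model.load_state_dict(state)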
@@ -22,15 +22,28 @@ def __init__(
         """Initialize TextCNN model."""
         super().__init__()

+        # Register model hyperparameters as buffers
+        self.register_buffer("vocab_size", torch.tensor(vocab_size))
+        self.register_buffer("n_classes", torch.tensor(n_classes))
+        self.register_buffer("embed_dim", torch.tensor(embed_dim))
+        self.register_buffer("kernel_sizes", torch.tensor(kernel_sizes))
+        self.register_buffer("num_filters", torch.tensor(num_filters))
+        self.register_buffer("dropout_rate", torch.tensor(dropout))
+        self.register_buffer("padding_idx", torch.tensor(padding_idx))
+
         if pretrained_embs is not None:
             _, embed_dim = pretrained_embs.shape
-            self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)  # type: ignore[no-untyped-call]
+            self.embedding = nn.Embedding.from_pretrained(pretrained_embs, freeze=True)  # type: ignore[no-untyped-call]
+            # Register pretrained embeddings as buffer if they exist
+            self.register_buffer("pretrained_embs", pretrained_embs)
         else:
             self.embedding = nn.Embedding(
                 num_embeddings=vocab_size,
                 embedding_dim=embed_dim,
                 padding_idx=padding_idx
             )
+            # Register None for pretrained_embs buffer
+            self.register_buffer("pretrained_embs", None)

         self.convs = nn.ModuleList([
             nn.Conv1d(
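Registering hyperparameters as buffers is what makes them travel with the weights: `nn.Module.state_dict()` serializes buffers alongside parameters. A minimal self-contained demonstration; the `Tiny` class is a stand-in, not part of the codebase:

import torch
import torch.nn as nn

class Tiny(nn.Module):
    """Stand-in module showing that buffers land in state_dict."""
    def __init__(self, embed_dim: int = 128) -> None:
        super().__init__()
        self.register_buffer("embed_dim", torch.tensor(embed_dim))

sd = Tiny(embed_dim=256).state_dict()
print(sd["embed_dim"])       # tensor(256)
print(int(sd["embed_dim"]))  # 256 -- the original Python value is recoverable

One subtlety: `register_buffer("pretrained_embs", None)` is valid, but a None buffer is simply omitted from `state_dict()`, so on load the key's absence, rather than a null entry, is what signals that no pretrained embeddings were used.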
@@ -50,3 +63,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         concatenated: torch.Tensor = torch.cat(conved, dim=1)
         dropped: torch.Tensor = self.dropout(concatenated)
         return self.fc(dropped)  # type: ignore[no-any-return]
+
+    def load(self, model_path: str) -> None:
+        """Load model from saved state.
+
+        Args:
+            model_path: Path to the saved model state dictionary.
+        """
+        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
+        self.load_state_dict(state_dict)
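With the hyperparameters stored as buffers, one way to restore a model end to end is to read them back from the checkpoint before sizing the instance. A sketch assuming the model was trained without pretrained embeddings; the path is hypothetical:

import torch

from autointent.modules.scoring._cnn.textcnn import TextCNN

# Read the checkpoint once to recover the hyperparameter buffers...
path = "dump/torch_models/scorer/model.pt"
state = torch.load(path, map_location="cpu")

# ...size the instance from them, then load the weights via the new helper.
model = TextCNN(
    vocab_size=int(state["vocab_size"]),
    n_classes=int(state["n_classes"]),
    embed_dim=int(state["embed_dim"]),
    kernel_sizes=[int(k) for k in state["kernel_sizes"]],
    num_filters=int(state["num_filters"]),
    dropout=float(state["dropout_rate"]),
    padding_idx=int(state["padding_idx"]),
)
model.load(path)
model.eval()  # disable dropout for inference

Because `load` passes `map_location=torch.device('cpu')`, checkpoints saved on GPU also deserialize on CPU-only machines; move the model with `.to(device)` afterwards if needed.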
