Skip to content

Commit 4680c08

Browse files
committed
fix
1 parent f7a6ca7 commit 4680c08

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

autointent/_dump_tools.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from autointent import Embedder, Ranker, VectorIndex
1717
from autointent.configs import CrossEncoderConfig, EmbedderConfig
18-
from autointent.modules.scoring._cnn.textcnn import TextCNN
1918
from autointent.schemas import TagsList
2019

2120
ModuleSimpleAttributes = None | str | int | float | bool | list # type: ignore[type-arg]
@@ -102,10 +101,10 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901, PLR0912, PLR0915
102101
model_path = path / Dumper.torch_models / key
103102
model_path.mkdir(parents=True, exist_ok=True)
104103
try:
105-
torch.save(val._model.state_dict(), model_path / "model.pt")
104+
torch.save(val._model.state_dict(), model_path / "model.pt") # noqa: SLF001
106105
vocab_path = path / Dumper.torch_models / "vocab.json"
107106
with vocab_path.open("w") as f:
108-
json.dump(obj._vocab, f)
107+
json.dump(val._vocab, f) # noqa: SLF001
109108
class_info = {
110109
"module": val.__class__.__module__,
111110
"name": val.__class__.__name__,
@@ -255,15 +254,15 @@ def load( # noqa: PLR0912, C901, PLR0915
255254
try:
256255
with (model_dir / "class_info.json").open("r") as f:
257256
class_info = json.load(f)
258-
vocab_path = path / Dumper.torch_models / "vocab.json"
259-
with vocab_path.open("r") as f:
260-
obj._vocab = json.load(f)
261257

262258
module = __import__(class_info["module"], fromlist=[class_info["name"]])
263259
model_class = getattr(module, class_info["name"])
264260

265261
# Create model instance
266262
model = model_class()
263+
vocab_path = path / Dumper.torch_models / "vocab.json"
264+
with vocab_path.open("r") as f:
265+
model._vocab = json.load(f) # noqa: SLF001
267266

268267
# Load state dict
269268
model.load_state_dict(torch.load(model_dir / "model.pt"))

autointent/modules/scoring/_cnn/textcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,5 +70,5 @@ def load(self, model_path: str) -> None:
7070
Args:
7171
model_path: Path to the saved model state dictionary.
7272
"""
73-
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
73+
state_dict = torch.load(model_path)
7474
self.load_state_dict(state_dict)

0 commit comments

Comments
 (0)