|
15 | 15 |
|
16 | 16 | from autointent import Embedder, Ranker, VectorIndex |
17 | 17 | from autointent.configs import CrossEncoderConfig, EmbedderConfig |
18 | | -from autointent.modules.scoring._cnn.textcnn import TextCNN |
19 | 18 | from autointent.schemas import TagsList |
20 | 19 |
|
21 | 20 | ModuleSimpleAttributes = None | str | int | float | bool | list # type: ignore[type-arg] |
@@ -102,10 +101,10 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901, PLR0912, PLR0915 |
102 | 101 | model_path = path / Dumper.torch_models / key |
103 | 102 | model_path.mkdir(parents=True, exist_ok=True) |
104 | 103 | try: |
105 | | - torch.save(val._model.state_dict(), model_path / "model.pt") |
| 104 | + torch.save(val._model.state_dict(), model_path / "model.pt") # noqa: SLF001 |
106 | 105 | vocab_path = path / Dumper.torch_models / "vocab.json" |
107 | 106 | with vocab_path.open("w") as f: |
108 | | - json.dump(obj._vocab, f) |
| 107 | + json.dump(val._vocab, f) # noqa: SLF001 |
109 | 108 | class_info = { |
110 | 109 | "module": val.__class__.__module__, |
111 | 110 | "name": val.__class__.__name__, |
@@ -255,15 +254,15 @@ def load( # noqa: PLR0912, C901, PLR0915 |
255 | 254 | try: |
256 | 255 | with (model_dir / "class_info.json").open("r") as f: |
257 | 256 | class_info = json.load(f) |
258 | | - vocab_path = path / Dumper.torch_models / "vocab.json" |
259 | | - with vocab_path.open("r") as f: |
260 | | - obj._vocab = json.load(f) |
261 | 257 |
|
262 | 258 | module = __import__(class_info["module"], fromlist=[class_info["name"]]) |
263 | 259 | model_class = getattr(module, class_info["name"]) |
264 | 260 |
|
265 | 261 | # Create model instance |
266 | 262 | model = model_class() |
| 263 | + vocab_path = path / Dumper.torch_models / "vocab.json" |
| 264 | + with vocab_path.open("r") as f: |
| 265 | + model._vocab = json.load(f) # noqa: SLF001 |
267 | 266 |
|
268 | 267 | # Load state dict |
269 | 268 | model.load_state_dict(torch.load(model_dir / "model.pt")) |
|
0 commit comments