Skip to content

Commit 439fde2

Browse files
voorhsriapushgithub-actions[bot]
authored
Feat/cnn scorer (#230)
* added first version of CNNScorer * Update optimizer_config.schema.json * fix ruff * fix ruff and mypy * fix ruff * fix furr for cnn.py * added dump, removed cnn_kwargs * Revert "added dump, removed cnn_kwargs" This reverts commit efd9362. * added test_cnn, added dump, fixed cnn_kwargs * fix ruff * fix ruff test_cnn * Update cnn.py * fixed from_context * mypy fix * fixes * more fixes * fix of kernel sizes * Update test_cnn.py * fix ruff * fix test * fix mypy * added dump load test * Update test_cnn.py * Update test_cnn.py * Update test_cnn.py * try to fix dumper * Update _dump_tools.py * added dump-load * fix * fix ruff * Update _dump_tools.py * Update _dump_tools.py * change order in dump * Update _dump_tools.py * Update _dump_tools.py * fix dumper * fix ruff * Update _dump_tools.py * add get_implicit_initialization_params, fix mypy partially * Update textcnn.py * Update textcnn.py * Update textcnn.py * fixed test * fix mypy * Update tensorboard.py * Update _dump_tools.py * add `_wrappers` submodule * refactor cnn dumping * codestyle * bug fix * move vocabulary management to textcnn from cnnscorer * codestyle * remove containers * stylistic changes * add device independence * trigger ci * add proper file encoding everywhere --------- Co-authored-by: riapush <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent f01d2c9 commit 439fde2

File tree

17 files changed

+540
-7
lines changed

17 files changed

+540
-7
lines changed

autointent/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
"""This is AutoIntent API reference."""
22

33
from ._logging import setup_logging
4-
from ._ranker import Ranker
5-
from ._embedder import Embedder
6-
from ._vector_index import VectorIndex
4+
from ._wrappers import Ranker, Embedder, VectorIndex
75
from ._dataset import Dataset
86
from ._hash import Hasher
97
from .context import Context, load_dataset

autointent/_dump_tools.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from peft import PeftModel
1111
from pydantic import BaseModel
1212
from sklearn.base import BaseEstimator
13+
from torch import nn
1314
from transformers import ( # type: ignore[attr-defined]
1415
AutoModelForSequenceClassification,
1516
AutoTokenizer,
@@ -19,14 +20,15 @@
1920
)
2021

2122
from autointent import Embedder, Ranker, VectorIndex
23+
from autointent._wrappers import BaseTorchModule
2224
from autointent.configs import CrossEncoderConfig, EmbedderConfig
2325
from autointent.context.optimization_info import Artifact
2426
from autointent.schemas import TagsList
2527

2628
ModuleSimpleAttributes = None | str | int | float | bool | list # type: ignore[type-arg]
2729

2830
ModuleAttributes: TypeAlias = (
29-
ModuleSimpleAttributes | TagsList | np.ndarray | Embedder | VectorIndex | BaseEstimator | Ranker # type: ignore[type-arg]
31+
ModuleSimpleAttributes | TagsList | np.ndarray | Embedder | VectorIndex | BaseEstimator | Ranker | nn.Module # type: ignore[type-arg]
3032
)
3133

3234
logger = logging.getLogger(__name__)
@@ -43,6 +45,7 @@ class Dumper:
4345
pydantic_models: str = "pydantic"
4446
hf_models = "hf_models"
4547
hf_tokenizers = "hf_tokenizers"
48+
torch_models = "torch_models"
4649
ptuning_models = "ptuning_models"
4750

4851
@staticmethod
@@ -62,6 +65,7 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
6265
path / Dumper.pydantic_models,
6366
path / Dumper.hf_models,
6467
path / Dumper.hf_tokenizers,
68+
path / Dumper.torch_models,
6569
path / Dumper.ptuning_models,
6670
]
6771
for subdir in subdirectories:
@@ -139,6 +143,20 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
139143
except Exception as e:
140144
msg = f"Error dumping HF model {key}: {e}"
141145
logger.exception(msg)
146+
elif isinstance(val, BaseTorchModule):
147+
model_path = path / Dumper.torch_models / key
148+
model_path.mkdir(parents=True, exist_ok=True)
149+
try:
150+
class_info = {
151+
"module": val.__class__.__module__,
152+
"name": val.__class__.__name__,
153+
}
154+
with (model_path / "class_info.json").open("w") as f:
155+
json.dump(class_info, f)
156+
val.dump(model_path)
157+
except Exception as e:
158+
msg = f"Error dumping torch model {key}: {e}"
159+
logger.exception(msg)
142160
elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
143161
tokenizer_path = path / Dumper.hf_tokenizers / key
144162
tokenizer_path.mkdir(parents=True, exist_ok=True)
@@ -174,6 +192,7 @@ def load( # noqa: C901, PLR0912, PLR0915
174192
pydantic_models: dict[str, Any] = {}
175193
hf_models: dict[str, Any] = {}
176194
hf_tokenizers: dict[str, Any] = {}
195+
torch_models: dict[str, Any] = {}
177196

178197
for child in path.iterdir():
179198
if child.name == Dumper.tags:
@@ -248,6 +267,18 @@ def load( # noqa: C901, PLR0912, PLR0915
248267
except Exception as e: # noqa: PERF203
249268
msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
250269
logger.exception(msg)
270+
elif child.name == Dumper.torch_models:
271+
try:
272+
for model_dir in child.iterdir():
273+
with (model_dir / "class_info.json").open("r") as f:
274+
class_info = json.load(f)
275+
module = importlib.import_module(class_info["module"])
276+
model_class: BaseTorchModule = getattr(module, class_info["name"])
277+
model = model_class.load(model_dir)
278+
torch_models[model_dir.name] = model
279+
except Exception as e:
280+
msg = f"Error loading torch model {model_dir.name}: {e}"
281+
logger.exception(msg)
251282
else:
252283
msg = f"Found unexpected child {child}"
253284
logger.error(msg)
@@ -263,4 +294,5 @@ def load( # noqa: C901, PLR0912, PLR0915
263294
| pydantic_models
264295
| hf_models
265296
| hf_tokenizers
297+
| torch_models
266298
)

autointent/_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from typing import TypeVar
44

5+
import torch
6+
57
T = TypeVar("T")
68

79

@@ -14,3 +16,12 @@ def _funcs_to_dict(*funcs: T) -> dict[str, T]:
1416
Dictionary of functions
1517
"""
1618
return {func.__name__: func for func in funcs} # type: ignore[attr-defined]
19+
20+
21+
def detect_device() -> str:
22+
"""Automatically detects CUDA, MPS and CPU."""
23+
if torch.cuda.is_available():
24+
return "cuda"
25+
if torch.mps.is_available():
26+
return "mps"
27+
return "cpu"

autointent/_wrappers/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .ranker import Ranker
2+
from .embedder import Embedder
3+
from .vector_index import VectorIndex
4+
from .base_torch_module import BaseTorchModule
5+
6+
__all__ = ["BaseTorchModule", "Embedder", "Ranker", "VectorIndex"]
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from abc import ABC, abstractmethod
2+
from pathlib import Path
3+
4+
import torch
5+
from torch import nn
6+
from typing_extensions import Self
7+
8+
9+
class BaseTorchModule(nn.Module, ABC):
10+
@abstractmethod
11+
def dump(self, path: Path) -> None:
12+
"""Dump torch module to disk.
13+
14+
This method encapsulates all the logic of dumping module's weights and
15+
hyperparameters required for initialization from disk and nice inference.
16+
17+
Args:
18+
path: path in file system
19+
"""
20+
21+
@classmethod
22+
@abstractmethod
23+
def load(cls, path: Path, device: str | None = None) -> Self:
24+
"""Load torch module from disk.
25+
26+
This method loads all weights and hyperparameters required for
27+
initialization from disk and inference.
28+
29+
Args:
30+
path: path in file system
31+
device: torch notation for CPU, CUDA, MPS, etc. By default, it is inferred automatically.
32+
"""
33+
34+
@property
35+
def device(self) -> torch.device:
36+
"""Torch device object where this module resides."""
37+
return next(self.parameters()).device
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from sentence_transformers import SentenceTransformer
2020
from sentence_transformers.similarity_functions import SimilarityFunction
2121

22-
from ._hash import Hasher
23-
from .configs import EmbedderConfig, TaskTypeEnum
22+
from autointent._hash import Hasher
23+
from autointent.configs import EmbedderConfig, TaskTypeEnum
2424

2525
logger = logging.getLogger(__name__)
2626

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import numpy as np
1515
import numpy.typing as npt
1616

17-
from autointent import Embedder
17+
from autointent._wrappers import Embedder
1818
from autointent.configs import EmbedderConfig, TaskTypeEnum, TokenizerConfig
1919
from autointent.custom_types import ListOfLabels
2020

autointent/modules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .scoring import (
1616
BERTLoRAScorer,
1717
BertScorer,
18+
CNNScorer,
1819
DescriptionScorer,
1920
DNNCScorer,
2021
KNNScorer,
@@ -48,6 +49,7 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
4849
SklearnScorer,
4950
MLKnnScorer,
5051
BertScorer,
52+
CNNScorer,
5153
BERTLoRAScorer,
5254
PTuningScorer,
5355
]

autointent/modules/scoring/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._bert import BertScorer
2+
from ._cnn import CNNScorer
23
from ._description import DescriptionScorer
34
from ._dnnc import DNNCScorer
45
from ._knn import KNNScorer, RerankScorer
@@ -11,6 +12,7 @@
1112
__all__ = [
1213
"BERTLoRAScorer",
1314
"BertScorer",
15+
"CNNScorer",
1416
"DNNCScorer",
1517
"DescriptionScorer",
1618
"KNNScorer",

0 commit comments

Comments
 (0)