Commit 6231290

Refactor/modules dumping and loading (#93)
* outline the idea
* refactor a little bit and add function implementations
* wrap into Dumper class
* stage progress
* refactor vector index dumping and loading
* refactor retrieval embedding module
* refactor embedder dumping and loading
* stage progress on implementing `_dump_tools.py`
* add TODO
* minor cleaning
* refactor mlknn scorer
* minor cleaning
* remove vector index client
* remove vector index client from modules implementation
* add proper dumping and loading for list[Tag]
* `NLITransformer` -> `CrossEncoder`
* add cross encoder handling
* DescriptionScorer: remove load/dump methods; properly define class attributes
* fix codestyle and typing
* DNNCScorer: proper attributes names
* DescriptionScorer: proper attributes names
* KNNScorer: properly define attributes
* some cleaning
* RerankScorer: decouple embedder and cross-encoder params
* minor cleaning
* MLKnnScorer: properly define class attributes
* LinearScorer: remove load/dump methods; properly define class attributes
* finish decoupling embedder and cross encoder params
* minor cleaning
* fix circular import issue
* minor test fix
* remove `db_dir` argument everywhere
* fix codestyle
* some bug fix
* some bug fix
* remove references to vector index client and add cross encoder config to pipeline and context
* add cross_encoder_config to CLI optimization
* `CrossEncoder` -> `Ranker`
1 parent 2bf20ec commit 6231290

44 files changed: +573 -1107 lines changed
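Taken together, these commits replace per-module dump/load implementations with a single shared Dumper (the new autointent/_dump_tools.py below) that serializes a module's attributes by type, rename the cross-encoder abstraction (NLITransformer -> CrossEncoder -> Ranker), and remove the vector index client. A minimal sketch of the resulting pattern; MyScorer is a hypothetical module invented here for illustration:

from pathlib import Path

import numpy as np

from autointent._dump_tools import Dumper


class MyScorer:
    """Hypothetical module: plain attributes, no custom dump/load methods."""

    def __init__(self) -> None:
        self.n_neighbors = 5               # simple attribute, goes to simple_attrs.json
        self.centroids = np.zeros((3, 8))  # ndarray, goes to arrays.npz


scorer = MyScorer()
Dumper.dump(scorer, Path("dumps/my_scorer"))    # writes typed subdirectories and files

restored = MyScorer()
Dumper.load(restored, Path("dumps/my_scorer"))  # repopulates restored.__dict__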

autointent/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -1,9 +1,12 @@
 """This is AutoIntent API reference."""
 
+from ._ranker import Ranker
 from ._embedder import Embedder
+from ._vector_index import VectorIndex
 from ._dataset import Dataset
 from ._hash import Hasher
 from .context import Context
 from ._pipeline import Pipeline
 
-__all__ = ["Context", "Dataset", "Embedder", "Hasher", "Pipeline"]
+
+__all__ = ["Context", "Dataset", "Embedder", "Hasher", "Pipeline", "Ranker", "VectorIndex"]
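With this change, Ranker (the renamed CrossEncoder) and VectorIndex become part of the public API:

from autointent import Embedder, Ranker, VectorIndex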

autointent/_dump_tools.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
+import json
+import logging
+from pathlib import Path
+from typing import Any, TypeAlias
+
+import joblib
+import numpy as np
+import numpy.typing as npt
+from sklearn.base import BaseEstimator
+
+from autointent import Embedder, Ranker, VectorIndex
+from autointent.schemas import TagsList
+
+ModuleSimpleAttributes = None | str | int | float | bool | list  # type: ignore[type-arg]
+
+ModuleAttributes: TypeAlias = (
+    ModuleSimpleAttributes | TagsList | np.ndarray | Embedder | VectorIndex | BaseEstimator | Ranker  # type: ignore[type-arg]
+)
+
+logger = logging.getLogger(__name__)
+
+
+class Dumper:
+    tags = "tags"
+    simple_attrs = "simple_attrs.json"
+    arrays = "arrays.npz"
+    embedders = "embedders"
+    indexes = "vector_indexes"
+    estimators = "estimators"
+    cross_encoders = "cross_encoders"
+
+    @staticmethod
+    def make_subdirectories(path: Path) -> None:
+        subdirectories = [
+            path / Dumper.tags,
+            path / Dumper.embedders,
+            path / Dumper.indexes,
+            path / Dumper.estimators,
+            path / Dumper.cross_encoders,
+        ]
+        for subdir in subdirectories:
+            subdir.mkdir(parents=True, exist_ok=True)
+
+    @staticmethod
+    def dump(obj: Any, path: Path) -> None:  # noqa: ANN401
+        """Dump module's attributes to the file system."""
+        attrs: dict[str, ModuleAttributes] = vars(obj)
+        simple_attrs = {}
+        arrays: dict[str, npt.NDArray[Any]] = {}
+
+        Dumper.make_subdirectories(path)
+
+        for key, val in attrs.items():
+            if isinstance(val, TagsList):
+                val.dump(path / Dumper.tags / key)
+            elif isinstance(val, ModuleSimpleAttributes):
+                simple_attrs[key] = val
+            elif isinstance(val, np.ndarray):
+                arrays[key] = val
+            elif isinstance(val, Embedder):
+                val.dump(path / Dumper.embedders / key)
+            elif isinstance(val, VectorIndex):
+                val.dump(path / Dumper.indexes / key)
+            elif isinstance(val, BaseEstimator):
+                joblib.dump(val, path / Dumper.estimators / key)
+            elif isinstance(val, Ranker):
+                val.save(str(path / Dumper.cross_encoders / key))
+            else:
+                msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system."
+                logger.error(msg)
+
+        with (path / Dumper.simple_attrs).open("w") as file:
+            json.dump(simple_attrs, file, ensure_ascii=False, indent=4)
+
+        np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
+
+    @staticmethod
+    def load(obj: Any, path: Path) -> None:  # noqa: ANN401
+        """Load attributes from the file system."""
+        for child in path.iterdir():
+            if child.name == Dumper.tags:
+                tags = {tags_dump.name: TagsList.load(tags_dump) for tags_dump in child.iterdir()}
+            elif child.name == Dumper.simple_attrs:
+                with child.open() as file:
+                    simple_attrs = json.load(file)
+            elif child.name == Dumper.arrays:
+                arrays = dict(np.load(child))
+            elif child.name == Dumper.embedders:
+                # TODO propagate custom loading params (such as device, batch size, etc.) to this line
+                embedders = {embedder_dump.name: Embedder.load(embedder_dump) for embedder_dump in child.iterdir()}
+            elif child.name == Dumper.indexes:
+                indexes = {index_dump.name: VectorIndex.load(index_dump) for index_dump in child.iterdir()}
+            elif child.name == Dumper.estimators:
+                estimators = {estimator_dump.name: joblib.load(estimator_dump) for estimator_dump in child.iterdir()}
+            elif child.name == Dumper.cross_encoders:
+                cross_encoders = {
+                    cross_encoder_dump.name: Ranker.load(cross_encoder_dump) for cross_encoder_dump in child.iterdir()
+                }
+            else:
+                msg = f"Found unexpected child {child}"
+                logger.error(msg)
+        obj.__dict__.update(tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders)
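Dumper.dump always creates every subdirectory and always writes simple_attrs.json and arrays.npz, so Dumper.load can simply iterate path.iterdir() without existence checks. For a module holding one Embedder and one sklearn estimator, the resulting layout would look roughly like this (illustrative; _embedder and _clf are hypothetical attribute names, which Dumper uses as file names):

dumps/my_module/
    simple_attrs.json
    arrays.npz
    tags/
    embedders/
        _embedder/
    vector_indexes/
    estimators/
        _clf
    cross_encoders/

On load, each category is collected into a dict keyed by attribute name, and all dicts are merged into the target object's __dict__ in a single update.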

autointent/_embedder.py

Lines changed: 25 additions & 19 deletions
@@ -37,10 +37,16 @@ def get_embeddings_path(filename: str) -> Path:
 class EmbedderDumpMetadata(TypedDict):
     """Metadata for saving and loading an Embedder instance."""
 
+    model_name_or_path: str
+    """Name of the hugging face model or a local path to sentence transformers dump."""
+    device: str
+    """Torch notation for CPU or CUDA."""
     batch_size: int
     """Batch size used for embedding calculations."""
     max_length: int | None
     """Maximum sequence length for the embedding model."""
+    use_cache: bool
+    """Whether to use embeddings caching."""
 
 
 class Embedder:
@@ -51,12 +57,11 @@ class Embedder:
     embedding models, as well as calculating embeddings for input texts.
     """
 
-    embedder_subdir: str = "sentence_transformers"
     metadata_dict_name: str = "metadata.json"
 
     def __init__(
         self,
-        model_name: str | Path,
+        model_name_or_path: str | Path,
         device: str = "cpu",
         batch_size: int = 32,
         max_length: int | None = None,
@@ -71,16 +76,13 @@ def __init__(
         :param max_length: Maximum sequence length for the embedding model.
         :param use_cache: Flag indicating whether to cache intermediate embeddings.
         """
-        self.model_name = model_name
+        self.model_name = model_name_or_path
         self.device = device
         self.batch_size = batch_size
         self.max_length = max_length
         self.use_cache = use_cache
 
-        if Path(model_name).exists():
-            self.load(model_name)
-        else:
-            self.embedding_model = SentenceTransformer(str(model_name), device=device)
+        self.embedding_model = SentenceTransformer(str(model_name_or_path), device=device)
 
         self.logger = logging.getLogger(__name__)
 
@@ -105,10 +107,7 @@ def clear_ram(self) -> None:
     def delete(self) -> None:
         """Delete the embedding model and its associated directory."""
         self.clear_ram()
-        shutil.rmtree(
-            self.dump_dir,
-            ignore_errors=True,
-        )  # TODO: `ignore_errors=True` is workaround for PermissionError: [WinError 5] Access is denied
+        shutil.rmtree(self.dump_dir)
 
     def dump(self, path: Path) -> None:
         """
@@ -118,28 +117,35 @@ def dump(self, path: Path) -> None:
         """
         self.dump_dir = path
         metadata = EmbedderDumpMetadata(
+            model_name_or_path=str(self.model_name),
+            device=self.device,
             batch_size=self.batch_size,
             max_length=self.max_length,
+            use_cache=self.use_cache,
         )
         path.mkdir(parents=True, exist_ok=True)
-        self.embedding_model.save(str(path / self.embedder_subdir))
         with (path / self.metadata_dict_name).open("w") as file:
             json.dump(metadata, file, indent=4)
 
-    def load(self, path: Path | str) -> None:
+    @classmethod
+    def load(
+        cls, path: Path | str, batch_size: int | None = None, use_cache: bool | None = None, device: str | None = None
+    ) -> "Embedder":
         """
         Load the embedding model and metadata from disk.
 
         :param path: Path to the directory where the model is stored.
         """
-        self.dump_dir = Path(path)
-        path = Path(path)
-        with (path / self.metadata_dict_name).open() as file:
+        with (Path(path) / cls.metadata_dict_name).open() as file:
            metadata: EmbedderDumpMetadata = json.load(file)
-        self.batch_size = metadata["batch_size"]
-        self.max_length = metadata["max_length"]
 
-        self.embedding_model = SentenceTransformer(str(path / self.embedder_subdir), device=self.device)
+        return cls(
+            model_name_or_path=metadata["model_name_or_path"],
+            device=device or metadata["device"],
+            batch_size=batch_size or metadata["batch_size"],
+            max_length=metadata["max_length"],
+            use_cache=use_cache or metadata["use_cache"],
+        )
 
     def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         """

autointent/_pipeline/_cli_endpoint.py

Lines changed: 1 addition & 1 deletion
@@ -27,14 +27,14 @@ def optimize(cfg: OptimizationConfig) -> None:
 
     logger.debug("Run Name: %s", cfg.logs.run_name)
     logger.debug("logs and assets: %s", cfg.logs.dirpath)
-    logger.debug("Vector index path: %s", cfg.vector_index.db_dir)
 
     # create shared objects for a whole pipeline
    context = Context(cfg.seed)
     cfg.logs.clear_ram = True
     context.configure_logging(cfg.logs)
     context.configure_vector_index(cfg.vector_index, cfg.embedder)
     context.configure_data(cfg.data)
+    context.configure_cross_encoder(cfg.cross_encoder)
 
     # run optimization
     search_space_config = load_config(cfg.task.search_space_path, context.is_multilabel(), logger)

autointent/_pipeline/_pipeline.py

Lines changed: 7 additions & 3 deletions
@@ -10,7 +10,7 @@
 import yaml
 
 from autointent import Context, Dataset
-from autointent.configs import EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
+from autointent.configs import CrossEncoderConfig, EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
 from autointent.custom_types import NodeType
 from autointent.metrics import PREDICTION_METRICS_MULTILABEL
 from autointent.nodes import InferenceNode, NodeOptimizer
@@ -38,11 +38,12 @@ def __init__(
             self.logging_config = LoggingConfig(dump_dir=None)
             self.vector_index_config = VectorIndexConfig()
             self.embedder_config = EmbedderConfig()
+            self.cross_encoder_config = CrossEncoderConfig()
         elif not isinstance(nodes[0], InferenceNode):
             msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
             raise TypeError(msg)
 
-    def set_config(self, config: LoggingConfig | VectorIndexConfig | EmbedderConfig) -> None:
+    def set_config(self, config: LoggingConfig | VectorIndexConfig | EmbedderConfig | CrossEncoderConfig) -> None:
         """
         Set configuration for the optimizer.
 
@@ -54,6 +55,8 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig | EmbedderConfig)
             self.vector_index_config = config
         elif isinstance(config, EmbedderConfig):
             self.embedder_config = config
+        elif isinstance(config, CrossEncoderConfig):
+            self.cross_encoder_config = config
         else:
             msg = "unknown config type"
             raise TypeError(msg)
@@ -97,7 +100,7 @@ def _fit(self, context: Context) -> None:
             node_optimizer.fit(context)  # type: ignore[union-attr]
         if not context.vector_index_config.save_db:
             self._logger.info("removing vector database from file system...")
-            context.vector_index_client.delete_db()
+            # TODO clear cache from appdirs
         self.context.callback_handler.end_run()
 
     def _is_inference(self) -> bool:
@@ -124,6 +127,7 @@ def fit(self, dataset: Dataset, force_multilabel: bool = False) -> Context:
         context.set_dataset(dataset, force_multilabel)
         context.configure_logging(self.logging_config)
         context.configure_vector_index(self.vector_index_config, self.embedder_config)
+        context.configure_cross_encoder(self.cross_encoder_config)
 
         self._fit(context)
 