Commit 0566f40

resolve conflicts

2 parents a90050c + cca9c0d

34 files changed: +290 −90 lines changed

Makefile

Lines changed: 6 additions & 0 deletions
@@ -38,5 +38,11 @@ test-docs: docs
 serve-docs: docs
 	$(poetry) python -m http.server -d docs/build/html 8333
 
+.PHONY: clean-docs
+clean-docs:
+	rm -rf docs/build
+	rm -rf docs/source/autoapi
+	rm -rf docs/source/tutorials
+
 .PHONY: all
 all: lint
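
With this target in place, wiping all generated documentation (the Sphinx build plus the generated autoapi and tutorials sources) before a fresh build is a one-liner; `docs` is the existing build target that `test-docs` and `serve-docs` already depend on:

make clean-docs docs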

autointent/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 from ._embedder import Embedder
 from ._dataset import Dataset
+from ._hash import Hasher
 from .context import Context
 from ._pipeline import Pipeline
 
-__all__ = ["Context", "Dataset", "Embedder", "Pipeline"]
+__all__ = ["Context", "Dataset", "Embedder", "Hasher", "Pipeline"]

autointent/_dataset/_dataset.py

Lines changed: 2 additions & 3 deletions
@@ -7,7 +7,6 @@
 
 from datasets import ClassLabel, Sequence, concatenate_datasets, get_dataset_config_names, load_dataset
 from datasets import Dataset as HFDataset
-from typing_extensions import Self
 
 from autointent.custom_types import LabelType, Split
 from autointent.schemas import Intent, Tag
@@ -122,7 +121,7 @@ def dump(self) -> dict[str, list[dict[str, Any]]]:
         """
         return {split_name: split.to_list() for split_name, split in self.items()}
 
-    def encode_labels(self) -> Self:
+    def encode_labels(self) -> "Dataset":
         """
         Encode dataset labels into one-hot or multilabel format.
 
@@ -133,7 +132,7 @@ def encode_labels(self) -> Self:
         self._encoded_labels = True
         return self
 
-    def to_multilabel(self) -> Self:
+    def to_multilabel(self) -> "Dataset":
         """
         Convert dataset labels to multilabel format.
 
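
The `Self` → string forward reference swap above repeats in `_validation.py` and `_pipeline/_pipeline.py` below; it drops the runtime import of `typing_extensions`, which was only needed because `typing.Self` first appeared in Python 3.11. A minimal sketch of the pattern (the trivial class body is illustrative, not the real `Dataset`):

# Before: requires the typing_extensions backport on Python < 3.11.
#
#     from typing_extensions import Self
#
#     def encode_labels(self) -> Self: ...

# After: a plain string forward reference, no extra import needed.
class Dataset:
    def encode_labels(self) -> "Dataset":
        return self

One trade-off worth noting: unlike `Self`, the string form pins the return annotation to `Dataset` itself, so type checkers will report the base class even when the method is called on a subclass.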

autointent/_dataset/_validation.py

Lines changed: 3 additions & 4 deletions
@@ -1,7 +1,6 @@
 """File with definitions of DatasetReader and DatasetValidator."""
 
 from pydantic import BaseModel, model_validator
-from typing_extensions import Self
 
 from autointent.schemas import Intent, Sample
 
@@ -21,7 +20,7 @@ class DatasetReader(BaseModel):
     intents: list[Intent] = []
 
     @model_validator(mode="after")
-    def validate_dataset(self) -> Self:
+    def validate_dataset(self) -> "DatasetReader":
         """
         Validate the dataset by ensuring intents and data splits are consistent.
 
@@ -33,7 +32,7 @@ def validate_dataset(self) -> Self:
             self._validate_split(split)
         return self
 
-    def _validate_intents(self) -> Self:
+    def _validate_intents(self) -> "DatasetReader":
         """
         Validate the intents by checking their IDs for sequential order.
 
@@ -52,7 +51,7 @@ def _validate_intents(self) -> Self:
             raise ValueError(message)
         return self
 
-    def _validate_split(self, split: list[Sample]) -> Self:
+    def _validate_split(self, split: list[Sample]) -> "DatasetReader":
         """
         Validate a dataset split to ensure all sample labels reference valid intent IDs.
 

autointent/_embedder.py

Lines changed: 53 additions & 2 deletions
@@ -12,8 +12,27 @@
 
 import numpy as np
 import numpy.typing as npt
+from appdirs import user_cache_dir
 from sentence_transformers import SentenceTransformer
 
+from ._hash import Hasher
+
+
+def get_embeddings_path(filename: str) -> Path:
+    """
+    Get the path to the embeddings file.
+
+    This function constructs the full path to an embeddings file stored
+    in a dedicated directory under the user's cache directory. The embeddings
+    file is named after the provided filename, with the `.npy` extension
+    added.
+
+    :param filename: The name of the embeddings file (without extension).
+
+    :return: The full path to the embeddings file.
+    """
+    return Path(user_cache_dir("autointent")) / "embeddings" / f"{filename}.npy"
+
 
 class EmbedderDumpMetadata(TypedDict):
     """Metadata for saving and loading an Embedder instance."""
@@ -41,6 +60,7 @@ def __init__(
         device: str = "cpu",
         batch_size: int = 32,
         max_length: int | None = None,
+        use_cache: bool = False,
     ) -> None:
         """
         Initialize the Embedder.
@@ -49,11 +69,13 @@
         :param device: Device to run the model on (e.g., "cpu", "cuda").
         :param batch_size: Batch size for embedding calculations.
         :param max_length: Maximum sequence length for the embedding model.
+        :param use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self.model_name = model_name
         self.device = device
         self.batch_size = batch_size
         self.max_length = max_length
+        self.use_cache = use_cache
 
         if Path(model_name).exists():
             self.load(model_name)
@@ -62,6 +84,18 @@ def __init__(
 
         self.logger = logging.getLogger(__name__)
 
+    def __hash__(self) -> int:
+        """
+        Compute a hash value for the Embedder.
+
+        :returns: The hash value of the Embedder.
+        """
+        hasher = Hasher()
+        for parameter in self.embedding_model.parameters():
+            hasher.update(parameter.detach().cpu().numpy())
+        hasher.update(self.max_length)
+        return hasher.intdigest()
+
     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
         self.logger.debug("Clearing embedder %s from memory", self.model_name)
@@ -114,18 +148,35 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
         :param utterances: List of input texts to calculate embeddings for.
         :return: A numpy array of embeddings.
         """
+        if self.use_cache:
+            hasher = Hasher()
+            hasher.update(self)
+            hasher.update(utterances)
+
+            embeddings_path = get_embeddings_path(hasher.hexdigest())
+            if embeddings_path.exists():
+                return np.load(embeddings_path)  # type: ignore[no-any-return]
+
         self.logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, device=%s",
             self.model_name,
             self.batch_size,
             str(self.max_length),
             self.device,
         )
+
         if self.max_length is not None:
             self.embedding_model.max_seq_length = self.max_length
-        return self.embedding_model.encode(
+
+        embeddings = self.embedding_model.encode(
             utterances,
             convert_to_numpy=True,
             batch_size=self.batch_size,
             normalize_embeddings=True,
-        )  # type: ignore[return-value]
+        )
+
+        if self.use_cache:
+            embeddings_path.parent.mkdir(parents=True, exist_ok=True)
+            np.save(embeddings_path, embeddings)
+
+        return embeddings  # type: ignore[return-value]
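
Taken together, these pieces give `embed()` a transparent on-disk cache: the key combines the model weights and `max_length` (via the new `__hash__`) with the exact utterance list, and hits are served from a `.npy` file under the user cache directory. A minimal usage sketch (the model name is illustrative, and the first call still has to run the model):

from autointent import Embedder

embedder = Embedder(
    "sentence-transformers/all-MiniLM-L6-v2",
    device="cpu",
    use_cache=True,
)

utterances = ["hello world", "goodbye world"]
first = embedder.embed(utterances)   # computed, then saved under <user_cache_dir>/autointent/embeddings/
second = embedder.embed(utterances)  # identical key: loaded straight from the .npy file
assert (first == second).all()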

autointent/_hash.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+"""This module provides functionality for hashing data using the xxhash algorithm."""
+
+import pickle
+from typing import Any
+
+import xxhash
+
+
+class Hasher:
+    """
+    A class that provides methods for hashing data using xxhash.
+
+    This class supports both a class-level method for generating hashes from
+    any given value, as well as an instance-level method for progressively
+    updating a hash state with new values.
+    """
+
+    def __init__(self) -> None:
+        """
+        Initialize the Hasher instance and set up the internal xxhash state.
+
+        This state will be used for progressively hashing values using the
+        `update` method and obtaining the final digest using `hexdigest`.
+        """
+        self._state = xxhash.xxh64()
+
+    @classmethod
+    def hash(cls, value: Any) -> int:  # noqa: ANN401
+        """
+        Generate a hash for the given value using xxhash.
+
+        :param value: The value to be hashed. This can be any Python object.
+
+        :return: The resulting hash digest as an integer.
+        """
+        if hasattr(value, "__hash__") and value.__hash__ not in {None, object.__hash__}:
+            return hash(value)
+        return xxhash.xxh64(pickle.dumps(value)).intdigest()
+
+    def update(self, value: Any) -> None:  # noqa: ANN401
+        """
+        Update the internal hash state with the provided value.
+
+        This method will first hash the type of the value, then hash the value
+        itself, and update the internal state accordingly.
+
+        :param value: The value to update the hash state with.
+        """
+        self._state.update(str(type(value)).encode())
+        self._state.update(str(self.hash(value)).encode())
+
+    def hexdigest(self) -> str:
+        """
+        Return the current hash digest as a hexadecimal string.
+
+        This method should be called after one or more `update` calls to get
+        the final hash result.
+
+        :return: The resulting hash digest as a hexadecimal string.
+        """
+        return self._state.hexdigest()
+
+    def intdigest(self) -> int:
+        """
+        Return the current hash digest as an integer.
+
+        This method should be called after one or more `update` calls to get
+        the final hash result.
+
+        :return: The resulting hash digest as an integer.
+        """
+        return self._state.intdigest()
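
A short sketch of the two hashing modes: values with a real `__hash__` go through Python's built-in `hash()`, everything else (lists, dicts, numpy arrays) is pickled and fed to xxh64, and `update()` also folds in each value's type so that, for example, `1` and `"1"` yield different digests:

from autointent import Hasher

# One-shot hashing: a list has no usable __hash__, so this takes the
# pickle + xxh64 path and is deterministic across processes.
print(Hasher.hash(["hello", 1]))

# Incremental hashing: each update() mixes type + value hash into one state.
hasher = Hasher()
hasher.update(["hello", "world"])
hasher.update(512)
print(hasher.hexdigest())  # hex digest, used above to name embedding cache files
print(hasher.intdigest())  # the same state, as an integer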

autointent/_pipeline/_pipeline.py

Lines changed: 5 additions & 6 deletions
@@ -8,7 +8,6 @@
 import numpy as np
 import numpy.typing as npt
 import yaml
-from typing_extensions import Self
 
 from autointent import Context, Dataset
 from autointent.configs import EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
@@ -60,7 +59,7 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig | EmbedderConfig)
         raise TypeError(msg)
 
     @classmethod
-    def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str) -> Self:
+    def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str) -> "Pipeline":
         """
         Create pipeline optimizer from dictionary search space.
 
@@ -73,7 +72,7 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str) -> S
         return cls(nodes)
 
     @classmethod
-    def default_optimizer(cls, multilabel: bool) -> Self:
+    def default_optimizer(cls, multilabel: bool) -> "Pipeline":
         """
         Create pipeline optimizer with default search space for given classification task.
 
@@ -137,7 +136,7 @@ def fit(self, dataset: Dataset, force_multilabel: bool = False, init_for_inferen
         return context
 
     @classmethod
-    def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> Self:
+    def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> "Pipeline":
         """
         Create inference pipeline from dictionary config.
 
@@ -147,7 +146,7 @@ def from_dict_config(cls, nodes_configs: list[dict[str, Any]]) -> Self:
         return cls.from_config([InferenceNodeConfig(**cfg) for cfg in nodes_configs])
 
     @classmethod
-    def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> Self:
+    def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> "Pipeline":
         """
         Create inference pipeline from config.
 
@@ -157,7 +156,7 @@ def from_config(cls, nodes_configs: list[InferenceNodeConfig]) -> Self:
         return cls(nodes)
 
     @classmethod
-    def load(cls, path: str | Path) -> Self:
+    def load(cls, path: str | Path) -> "Pipeline":
         """
         Load pipeline in inference mode.
 

autointent/configs/_optimization_cli.py

Lines changed: 2 additions & 0 deletions
@@ -107,6 +107,8 @@ class EmbedderConfig:
     """Batch size for the embedder"""
     max_length: int | None = None
     """Max length for the embedder. If None, the max length will be taken from model config"""
+    use_cache: bool = False
+    """Flag indicating whether to cache embeddings for reuse, improving performance in repeated operations."""
 
 
 @dataclass
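
A hedged sketch of turning the flag on through the config layer, assuming `EmbedderConfig`'s remaining fields all have defaults; `set_config` accepting an `EmbedderConfig` is visible in the `_pipeline.py` diff above:

from autointent import Pipeline
from autointent.configs import EmbedderConfig

pipeline = Pipeline.default_optimizer(multilabel=False)
pipeline.set_config(EmbedderConfig(use_cache=True))  # embeddings are now cached across repeated embed() calls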

autointent/context/_context.py

Lines changed: 9 additions & 0 deletions
@@ -69,6 +69,7 @@ def configure_vector_index(self, config: VectorIndexConfig, embedder_config: Emb
             self.vector_index_config.db_dir,
             self.embedder_config.batch_size,
             self.embedder_config.max_length,
+            self.embedder_config.use_cache,
         )
 
     def configure_data(self, config: DataConfig) -> None:
@@ -189,6 +190,14 @@ def get_max_length(self) -> int | None:
         """
         return self.vector_index_client.embedder_max_length
 
+    def get_use_cache(self) -> bool:
+        """
+        Check if caching is enabled for the embedder.
+
+        :return: True if caching is enabled, False otherwise.
+        """
+        return self.vector_index_client.embedder_use_cache
+
     def get_dump_dir(self) -> Path | None:
         """
         Get the directory for saving dumped modules.

autointent/context/vector_index_client/_vector_index.py

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,7 @@ def __init__(
         device: str,
         embedder_batch_size: int = 32,
         embedder_max_length: int | None = None,
+        embedder_use_cache: bool = False,
     ) -> None:
         """
         Initialize the vector index.
@@ -39,13 +40,15 @@ def __init__(
         :param device: Device for running the embedding model (e.g., "cpu", "cuda").
         :param embedder_batch_size: Batch size for the embedder.
         :param embedder_max_length: Maximum sequence length for the embedder.
+        :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
         """
         self.model_name = model_name
         self.embedder = Embedder(
             model_name=model_name,
             batch_size=embedder_batch_size,
             device=device,
             max_length=embedder_max_length,
+            use_cache=embedder_use_cache,
         )
         self.device = device
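
The flag is thus threaded through the full stack (EmbedderConfig → Context → vector index → Embedder), so enabling it at any level reaches the underlying model. A sketch at this level, assuming the class defined in `_vector_index.py` is named `VectorIndex` and takes no other required arguments:

# Import path and class name are assumptions based on the file path above.
from autointent.context.vector_index_client._vector_index import VectorIndex

index = VectorIndex(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model
    device="cpu",
    embedder_use_cache=True,  # forwarded to Embedder(use_cache=True)
)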
