-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy path_embedder.py
More file actions
187 lines (150 loc) · 6.17 KB
/
_embedder.py
File metadata and controls
187 lines (150 loc) · 6.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
"""Module for managing embedding models using Sentence Transformers.
This module provides the `Embedder` class for managing, persisting, and loading
embedding models and calculating embeddings for input texts.
"""
import json
import logging
import shutil
from pathlib import Path
from typing import TypedDict
import numpy as np
import numpy.typing as npt
import torch
from appdirs import user_cache_dir
from sentence_transformers import SentenceTransformer
from ._hash import Hasher
from .configs import EmbedderConfig, TaskTypeEnum
def get_embeddings_path(filename: str) -> Path:
    """
    Get the path to the embeddings file.

    This function constructs the full path to an embeddings file stored
    in the application cache directory (resolved via ``appdirs.user_cache_dir``).
    The embeddings file is named based on the provided filename, with the
    `.npy` extension added.

    :param filename: The name of the embeddings file (without extension).
    :return: The full path to the embeddings file.
    """
    # Fix: the original returned a literal placeholder string instead of
    # interpolating `filename`, so every cache entry collided on one file.
    return Path(user_cache_dir("autointent")) / "embeddings" / f"{filename}.npy"
class EmbedderDumpMetadata(TypedDict):
"""Metadata for saving and loading an Embedder instance."""
model_name_or_path: str
"""Name of the hugging face model or a local path to sentence transformers dump."""
device: str | None
"""Torch notation for CPU or CUDA."""
batch_size: int
"""Batch size used for embedding calculations."""
max_length: int | None
"""Maximum sequence length for the embedding model."""
use_cache: bool
"""Whether to use embeddings caching."""
class Embedder:
    """
    A wrapper for managing embedding models using Sentence Transformers.

    This class handles initialization, saving, loading, and clearing of
    embedding models, as well as calculating embeddings for input texts.
    """

    # File name of the metadata JSON written inside a dump directory.
    metadata_dict_name: str = "metadata.json"
    # Directory this embedder was last dumped to, if any.
    dump_dir: Path | None = None

    def __init__(self, embedder_config: EmbedderConfig) -> None:
        """
        Initialize the Embedder.

        :param embedder_config: Config of embedder.
        """
        self.embedding_config = embedder_config
        self.model_name = embedder_config.model_name
        self.device = embedder_config.device
        self.batch_size = embedder_config.batch_size
        self.max_length = embedder_config.max_length
        self.use_cache = embedder_config.use_cache
        self.logger = logging.getLogger(__name__)
        self.embedding_model = SentenceTransformer(
            self.model_name, device=self.device, prompts=embedder_config.get_prompt_config()
        )

    def __hash__(self) -> int:
        """
        Compute a hash value for the Embedder.

        :returns: The hash value of the Embedder.
        """
        state = Hasher()
        # Fold every model weight into the hash so distinct checkpoints
        # (even under the same model name) produce distinct cache keys.
        for weight in self.embedding_model.parameters():
            state.update(weight.detach().cpu().numpy())
        state.update(self.max_length)
        return state.intdigest()

    def clear_ram(self) -> None:
        """Move the embedding model to CPU and delete it from memory."""
        self.logger.debug("Clearing embedder %s from memory", self.model_name)
        self.embedding_model.cpu()
        del self.embedding_model
        # Release any CUDA memory the model held; no-op on CPU-only setups.
        torch.cuda.empty_cache()

    def delete(self) -> None:
        """Delete the embedding model and its associated directory."""
        self.clear_ram()
        if self.dump_dir is None:
            return
        shutil.rmtree(self.dump_dir)

    def dump(self, path: Path) -> None:
        """
        Save the embedder metadata to disk.

        Note: only the metadata needed to re-create the embedder is written;
        the model weights themselves are re-fetched by name on load.

        :param path: Path to the directory where the model will be saved.
        """
        self.dump_dir = path
        path.mkdir(parents=True, exist_ok=True)
        metadata = EmbedderDumpMetadata(
            model_name_or_path=str(self.model_name),
            device=self.device,
            batch_size=self.batch_size,
            max_length=self.max_length,
            use_cache=self.use_cache,
        )
        with (path / self.metadata_dict_name).open("w") as file:
            json.dump(metadata, file, indent=4)

    @classmethod
    def load(cls, path: Path | str) -> "Embedder":
        """
        Load the embedding model and metadata from disk.

        :param path: Path to the directory where the model is stored.
        """
        metadata_file = Path(path) / cls.metadata_dict_name
        with metadata_file.open() as file:
            metadata: EmbedderDumpMetadata = json.load(file)
        config = EmbedderConfig(
            model_name=metadata["model_name_or_path"],
            device=metadata["device"],
            batch_size=metadata["batch_size"],
            max_length=metadata["max_length"],
            use_cache=metadata["use_cache"],
        )
        return cls(config)

    def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
        """
        Calculate embeddings for a list of utterances.

        :param utterances: List of input texts to calculate embeddings for.
        :param task_type: Type of task for which embeddings are calculated.
        :return: A numpy array of embeddings.
        """
        cache_file: Path | None = None
        if self.use_cache:
            # Cache key covers both the model state (via __hash__) and the inputs.
            key = Hasher()
            key.update(self)
            key.update(utterances)
            cache_file = get_embeddings_path(key.hexdigest())
            if cache_file.exists():
                return np.load(cache_file)  # type: ignore[no-any-return]
        self.logger.debug(
            "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s",
            self.model_name,
            self.batch_size,
            str(self.max_length),
            self.device,
        )
        if self.max_length is not None:
            self.embedding_model.max_seq_length = self.max_length
        embeddings = self.embedding_model.encode(
            utterances,
            convert_to_numpy=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt_name=self.embedding_config.get_prompt_type(task_type),
        )
        if cache_file is not None:
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            np.save(cache_file, embeddings)
        return embeddings