Skip to content

Commit 4645706

Browse files
authored
Add models configs (#124)
* init
* init
* fix name
* part fix loader
* fix tests
* fix tests
* lint
* fix typing
* update docs
* fix more docs
* lint
* remove commented code
* fix config
* fix last config
* fix default configs
* add dict type to init
* fix docstring
* lint
* resolve part conversations
* fix rerank scorer
* add prompts to encode
* add descriptions to schemes
* fix tests
* fix imports
* fix import config
1 parent 0979234 commit 4645706

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

63 files changed

+791
-829
lines changed

autointent/_datafiles/default-multiclass-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
search_space:
55
- module_name: retrieval
66
k: [10]
7-
embedder_name:
7+
embedder_config:
88
- avsolatorio/GIST-small-Embedding-v0
99
- sergeyzh/rubert-tiny-turbo
1010
- node_type: scoring
@@ -15,7 +15,7 @@
1515
weights: ["uniform", "distance", "closest"]
1616
- module_name: linear
1717
- module_name: dnnc
18-
cross_encoder_name:
18+
cross_encoder_config:
1919
- cross-encoder/ms-marco-MiniLM-L-6-v2
2020
k: [1, 3, 5, 10]
2121
- node_type: decision

autointent/_datafiles/default-multilabel-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
search_space:
55
- module_name: retrieval
66
k: [10]
7-
embedder_name:
7+
embedder_config:
88
- deepvk/USER-bge-m3
99
- node_type: scoring
1010
target_metric: scoring_roc_auc

autointent/_datafiles/inference-config-example.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
module_name: retrieval
33
module_config:
44
k: 10
5-
model_name: sergeyzh/rubert-tiny-turbo
5+
model_config: sergeyzh/rubert-tiny-turbo
66
load_path: .
77
- node_type: scoring
88
module_name: knn

autointent/_embedder.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sentence_transformers import SentenceTransformer
1818

1919
from ._hash import Hasher
20+
from .schemas import EmbedderConfig, TaskTypeEnum
2021

2122

2223
def get_embeddings_path(filename: str) -> Path:
@@ -40,7 +41,7 @@ class EmbedderDumpMetadata(TypedDict):
4041

4142
model_name_or_path: str
4243
"""Name of the hugging face model or a local path to sentence transformers dump."""
43-
device: str
44+
device: str | None
4445
"""Torch notation for CPU or CUDA."""
4546
batch_size: int
4647
"""Batch size used for embedding calculations."""
@@ -61,30 +62,22 @@ class Embedder:
6162
metadata_dict_name: str = "metadata.json"
6263
dump_dir: Path | None = None
6364

64-
def __init__(
65-
self,
66-
model_name_or_path: str | Path,
67-
device: str = "cpu",
68-
batch_size: int = 32,
69-
max_length: int | None = None,
70-
use_cache: bool = True,
71-
) -> None:
65+
def __init__(self, embedder_config: EmbedderConfig) -> None:
7266
"""
7367
Initialize the Embedder.
7468
75-
:param model_name_or_path: Path to a local model directory or a Hugging Face model name.
76-
:param device: Device to run the model on (e.g., "cpu", "cuda").
77-
:param batch_size: Batch size for embedding calculations.
78-
:param max_length: Maximum sequence length for the embedding model.
79-
:param use_cache: Flag indicating whether to cache intermediate embeddings.
69+
:param embedder_config: Config of embedder.
8070
"""
81-
self.model_name = model_name_or_path
82-
self.device = device
83-
self.batch_size = batch_size
84-
self.max_length = max_length
85-
self.use_cache = use_cache
86-
87-
self.embedding_model = SentenceTransformer(str(model_name_or_path), device=device)
71+
self.model_name = embedder_config.model_name
72+
self.device = embedder_config.device
73+
self.batch_size = embedder_config.batch_size
74+
self.max_length = embedder_config.max_length
75+
self.use_cache = embedder_config.use_cache
76+
self.embedding_config = embedder_config
77+
78+
self.embedding_model = SentenceTransformer(
79+
self.model_name, device=self.device, prompts=embedder_config.get_prompt_config()
80+
)
8881

8982
self.logger = logging.getLogger(__name__)
9083

@@ -132,9 +125,7 @@ def dump(self, path: Path) -> None:
132125
json.dump(metadata, file, indent=4)
133126

134127
@classmethod
135-
def load(
136-
cls, path: Path | str, batch_size: int | None = None, use_cache: bool | None = None, device: str | None = None
137-
) -> "Embedder":
128+
def load(cls, path: Path | str) -> "Embedder":
138129
"""
139130
Load the embedding model and metadata from disk.
140131
@@ -144,18 +135,21 @@ def load(
144135
metadata: EmbedderDumpMetadata = json.load(file)
145136

146137
return cls(
147-
model_name_or_path=metadata["model_name_or_path"],
148-
device=device or metadata["device"],
149-
batch_size=batch_size or metadata["batch_size"],
150-
max_length=metadata["max_length"],
151-
use_cache=use_cache or metadata["use_cache"],
138+
EmbedderConfig(
139+
model_name=metadata["model_name_or_path"],
140+
device=metadata["device"],
141+
batch_size=metadata["batch_size"],
142+
max_length=metadata["max_length"],
143+
use_cache=metadata["use_cache"],
144+
)
152145
)
153146

154-
def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
147+
def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
155148
"""
156149
Calculate embeddings for a list of utterances.
157150
158151
:param utterances: List of input texts to calculate embeddings for.
152+
:param task_type: Type of task for which embeddings are calculated.
159153
:return: A numpy array of embeddings.
160154
"""
161155
if self.use_cache:
@@ -183,6 +177,7 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
183177
convert_to_numpy=True,
184178
batch_size=self.batch_size,
185179
normalize_embeddings=True,
180+
prompt_name=self.embedding_config.get_prompt_type(task_type),
186181
)
187182

188183
if self.use_cache:

autointent/_pipeline/_pipeline.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import yaml
1010

1111
from autointent import Context, Dataset
12-
from autointent.configs import CrossEncoderConfig, EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
12+
from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
1313
from autointent.custom_types import ListOfGenericLabels, NodeType, ValidationScheme
1414
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
1515
from autointent.nodes import InferenceNode, NodeOptimizer
@@ -43,13 +43,11 @@ def __init__(
4343
if isinstance(nodes[0], NodeOptimizer):
4444
self.logging_config = LoggingConfig(dump_dir=None)
4545
self.vector_index_config = VectorIndexConfig()
46-
self.embedder_config = EmbedderConfig()
47-
self.cross_encoder_config = CrossEncoderConfig()
4846
elif not isinstance(nodes[0], InferenceNode):
4947
msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
5048
raise TypeError(msg)
5149

52-
def set_config(self, config: LoggingConfig | VectorIndexConfig | EmbedderConfig | CrossEncoderConfig) -> None:
50+
def set_config(self, config: LoggingConfig | VectorIndexConfig) -> None:
5351
"""
5452
Set configuration for the optimizer.
5553
@@ -59,10 +57,6 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig | EmbedderConfig
5957
self.logging_config = config
6058
elif isinstance(config, VectorIndexConfig):
6159
self.vector_index_config = config
62-
elif isinstance(config, EmbedderConfig):
63-
self.embedder_config = config
64-
elif isinstance(config, CrossEncoderConfig):
65-
self.cross_encoder_config = config
6660
else:
6761
msg = "unknown config type"
6862
raise TypeError(msg)
@@ -138,8 +132,8 @@ def fit(
138132
context = Context()
139133
context.set_dataset(dataset, scheme, n_folds)
140134
context.configure_logging(self.logging_config)
141-
context.configure_vector_index(self.vector_index_config, self.embedder_config)
142-
context.configure_cross_encoder(self.cross_encoder_config)
135+
context.configure_vector_index(self.vector_index_config)
136+
143137
self.validate_modules(dataset)
144138
self._fit(context)
145139

autointent/_ranker.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
from torch import nn
2121

2222
from autointent.custom_types import ListOfLabels
23+
from autointent.schemas import CrossEncoderConfig
2324

2425
logger = logging.getLogger(__name__)
2526

2627

2728
class CrossEncoderMetadata(TypedDict):
2829
model_name: str
2930
train_classifier: bool
30-
device: str
31+
device: str | None
3132
max_length: int | None
3233
batch_size: int
3334

@@ -105,32 +106,27 @@ class Ranker:
105106

106107
def __init__(
107108
self,
108-
model_name: str,
109-
device: str = "cpu",
110-
train_classifier: bool = False,
111-
batch_size: int = 326,
112-
max_length: int | None = None,
109+
cross_encoder_config: CrossEncoderConfig | str | dict[str, Any],
113110
classifier_head: LogisticRegressionCV | None = None,
114111
) -> None:
115112
"""
116113
Initialize the Ranker.
117114
118-
:param model: The cross-encoder hugging face model name to use.
119-
:param device: Device to run operations on, e.g., "cpu" or "cuda".
120-
:param train_classifier: Whether to train a custom classifier, defaults to False.
121-
:param batch_size: Batch size for processing text pairs, defaults to 326.
115+
:param cross_encoder_config: Config of the cross-encoder hugging face model name to use.
122116
:param max_length (int, optional): Max length for input sequences for the cross encoder.
123117
:param classifier_head (LogisticRegressionCV, optional): Classifier (to be used in restore procedure mainly).
124118
"""
125-
self.model_name = model_name
126-
self.device = device
127-
self.cross_encoder = st.CrossEncoder(model_name, trust_remote_code=True, device=device, max_length=max_length) # type: ignore[arg-type]
119+
self.cross_encoder_config = CrossEncoderConfig.from_search_config(cross_encoder_config)
120+
self.cross_encoder = st.CrossEncoder(
121+
self.cross_encoder_config.model_name,
122+
trust_remote_code=True,
123+
device=self.cross_encoder_config.device,
124+
max_length=self.cross_encoder_config.max_length, # type: ignore[arg-type]
125+
)
128126
self.train_classifier = False
129-
self.batch_size = batch_size
130-
self.max_length = max_length
131127
self._clf = classifier_head
132128

133-
if classifier_head is not None or train_classifier:
129+
if classifier_head is not None or self.cross_encoder_config.train_head:
134130
self.train_classifier = True
135131
self._activations_list: list[npt.NDArray[Any]] = []
136132
self._hook_handler = self.cross_encoder.model.classifier.register_forward_hook(self._classifier_hook)
@@ -150,10 +146,16 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr
150146
:return: Numpy array of extracted features.
151147
"""
152148
if not self.train_classifier:
153-
return np.array(self.cross_encoder.predict(pairs, batch_size=self.batch_size, activation_fct=nn.Sigmoid()))
149+
return np.array(
150+
self.cross_encoder.predict(
151+
pairs,
152+
batch_size=self.cross_encoder_config.batch_size,
153+
activation_fct=nn.Sigmoid(),
154+
)
155+
)
154156

155157
# put the data through, features will be taken in the hook
156-
self.cross_encoder.predict(pairs, batch_size=self.batch_size)
158+
self.cross_encoder.predict(pairs, batch_size=self.cross_encoder_config.batch_size)
157159

158160
res = np.concatenate(self._activations_list, axis=0)
159161
self._activations_list.clear()
@@ -223,8 +225,8 @@ def rank(
223225
Rank documents according to meaning closeness to the query.
224226
225227
:param query: The reference document.
226-
:query_docs: List of documents to rank
227-
:top_k: how many document to return
228+
:param query_docs: List of documents to rank
229+
:param top_k: how many document to return
228230
:return: array of dictionaries of ranked items.
229231
"""
230232
query_doc_pairs = [(query, doc) for doc in query_docs]
@@ -247,11 +249,11 @@ def save(self, path: str) -> None:
247249
dump_dir.mkdir(parents=True)
248250

249251
metadata = CrossEncoderMetadata(
250-
model_name=self.model_name,
252+
model_name=self.cross_encoder_config.model_name,
251253
train_classifier=self.train_classifier,
252-
device=self.device,
253-
max_length=self.max_length,
254-
batch_size=self.batch_size,
254+
device=self.cross_encoder_config.device,
255+
max_length=self.cross_encoder_config.max_length,
256+
batch_size=self.cross_encoder_config.batch_size,
255257
)
256258

257259
with (dump_dir / self.metadata_file_name).open("w") as file:
@@ -272,7 +274,16 @@ def load(cls, path: Path) -> "Ranker":
272274
with (path / cls.metadata_file_name).open() as file:
273275
metadata: CrossEncoderMetadata = json.load(file)
274276

275-
return cls(**metadata, classifier_head=clf)
277+
return cls(
278+
CrossEncoderConfig(
279+
model_name=metadata["model_name"],
280+
device=metadata["device"],
281+
max_length=metadata["max_length"],
282+
batch_size=metadata["batch_size"],
283+
train_head=metadata["train_classifier"],
284+
),
285+
classifier_head=clf,
286+
)
276287

277288
def clear_ram(self) -> None:
278289
self.cross_encoder.model.cpu()

autointent/_vector_index.py

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
from autointent import Embedder
1818
from autointent.custom_types import ListOfLabels
19+
from autointent.schemas import EmbedderConfig, TaskTypeEnum
1920

2021

2122
class VectorIndexMetadata(TypedDict):
2223
embedder_model_name: str
23-
embedder_device: str
24+
embedder_device: str | None
2425
embedder_batch_size: int
2526
embedder_max_length: int | None
2627
embedder_use_cache: bool
@@ -42,31 +43,13 @@ class VectorIndex:
4243
_data_file = "data.json"
4344
_meta_data_file = "metadata.json"
4445

45-
def __init__(
46-
self,
47-
embedder_model_name: str,
48-
embedder_device: str,
49-
embedder_batch_size: int = 32,
50-
embedder_max_length: int | None = None,
51-
embedder_use_cache: bool = True,
52-
) -> None:
46+
def __init__(self, embedder_config: EmbedderConfig) -> None:
5347
"""
5448
Initialize the vector index.
5549
56-
:param embedder_model_name: Name of the embedding model to use.
57-
:param embedder_device: Device for running the embedding model (e.g., "cpu", "cuda").
58-
:param embedder_batch_size: Batch size for the embedder.
59-
:param embedder_max_length: Maximum sequence length for the embedder.
60-
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
61-
"""
62-
self.embedder = Embedder(
63-
model_name_or_path=embedder_model_name,
64-
batch_size=embedder_batch_size,
65-
device=embedder_device,
66-
max_length=embedder_max_length,
67-
use_cache=embedder_use_cache,
68-
)
69-
self.embedder_device = embedder_device
50+
:param embedder_config: Config of the embedding model to use.
51+
"""
52+
self.embedder = Embedder(embedder_config)
7053

7154
self.labels: ListOfLabels = [] # (n_samples,) or (n_samples, n_classes)
7255
self.texts: list[str] = []
@@ -81,7 +64,7 @@ def add(self, texts: list[str], labels: ListOfLabels) -> None:
8164
:param labels: List of labels corresponding to the texts.
8265
"""
8366
self.logger.debug("Adding embeddings to vector index %s", self.embedder.model_name)
84-
embeddings = self.embedder.embed(texts)
67+
embeddings = self.embedder.embed(texts, TaskTypeEnum.passage)
8568

8669
if not hasattr(self, "index"):
8770
self.index = faiss.IndexFlatIP(embeddings.shape[1])
@@ -120,7 +103,7 @@ def _search_by_text(self, texts: list[str], k: int) -> list[list[dict[str, Any]]
120103
:param k: Number of nearest neighbors to return.
121104
:return: List of search results for each query.
122105
"""
123-
query_embedding: npt.NDArray[np.float64] = self.embedder.embed(texts) # type: ignore[assignment]
106+
query_embedding: npt.NDArray[np.float64] = self.embedder.embed(texts, TaskTypeEnum.query) # type: ignore[assignment]
124107
return self._search_by_embedding(query_embedding, k)
125108

126109
def _search_by_embedding(self, embedding: npt.NDArray[Any], k: int) -> list[list[dict[str, Any]]]:
@@ -233,11 +216,13 @@ def load(
233216
metadata: VectorIndexMetadata = json.load(file)
234217

235218
instance = cls(
236-
embedder_model_name=metadata["embedder_model_name"],
237-
embedder_device=embedder_device or metadata["embedder_device"],
238-
embedder_batch_size=embedder_batch_size or metadata["embedder_batch_size"],
239-
embedder_max_length=metadata["embedder_max_length"],
240-
embedder_use_cache=embedder_use_cache or metadata["embedder_use_cache"],
219+
EmbedderConfig(
220+
model_name=metadata["embedder_model_name"],
221+
device=embedder_device or metadata["embedder_device"],
222+
batch_size=embedder_batch_size or metadata["embedder_batch_size"],
223+
max_length=metadata["embedder_max_length"],
224+
use_cache=embedder_use_cache or metadata["embedder_use_cache"],
225+
)
241226
)
242227

243228
with (dir_path / cls._data_file).open() as file:

0 commit comments

Comments
 (0)