Skip to content

Commit 8902fc1

Browse files
committed
init
1 parent 15016a5 commit 8902fc1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+600
-674
lines changed

autointent/_embedder.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sentence_transformers import SentenceTransformer
1818

1919
from ._hash import Hasher
20+
from .schemas._schemas import EmbedderConfig
2021

2122

2223
def get_embeddings_path(filename: str) -> Path:
@@ -40,7 +41,7 @@ class EmbedderDumpMetadata(TypedDict):
4041

4142
model_name_or_path: str
4243
"""Name of the hugging face model or a local path to sentence transformers dump."""
43-
device: str
44+
device: str | None
4445
"""Torch notation for CPU or CUDA."""
4546
batch_size: int
4647
"""Batch size used for embedding calculations."""
@@ -61,30 +62,19 @@ class Embedder:
6162
metadata_dict_name: str = "metadata.json"
6263
dump_dir: Path | None = None
6364

64-
def __init__(
65-
self,
66-
model_name_or_path: str | Path,
67-
device: str = "cpu",
68-
batch_size: int = 32,
69-
max_length: int | None = None,
70-
use_cache: bool = True,
71-
) -> None:
65+
def __init__(self, embedder_config: EmbedderConfig) -> None:
7266
"""
7367
Initialize the Embedder.
7468
75-
:param model_name_or_path: Path to a local model directory or a Hugging Face model name.
76-
:param device: Device to run the model on (e.g., "cpu", "cuda").
77-
:param batch_size: Batch size for embedding calculations.
78-
:param max_length: Maximum sequence length for the embedding model.
79-
:param use_cache: Flag indicating whether to cache intermediate embeddings.
69+
:param embedder_config: Embedder configuration providing the model name, device, batch size, max length and cache flag.
8070
"""
81-
self.model_name = model_name_or_path
82-
self.device = device
83-
self.batch_size = batch_size
84-
self.max_length = max_length
85-
self.use_cache = use_cache
71+
self.model_name = embedder_config.model_name
72+
self.device = embedder_config.device
73+
self.batch_size = embedder_config.batch_size
74+
self.max_length = embedder_config.max_length
75+
self.use_cache = embedder_config.use_cache
8676

87-
self.embedding_model = SentenceTransformer(str(model_name_or_path), device=device)
77+
self.embedding_model = SentenceTransformer(self.model_name, device=self.device)
8878

8979
self.logger = logging.getLogger(__name__)
9080

@@ -132,9 +122,7 @@ def dump(self, path: Path) -> None:
132122
json.dump(metadata, file, indent=4)
133123

134124
@classmethod
135-
def load(
136-
cls, path: Path | str, batch_size: int | None = None, use_cache: bool | None = None, device: str | None = None
137-
) -> "Embedder":
125+
def load(cls, path: Path | str) -> "Embedder":
138126
"""
139127
Load the embedding model and metadata from disk.
140128
@@ -144,11 +132,13 @@ def load(
144132
metadata: EmbedderDumpMetadata = json.load(file)
145133

146134
return cls(
147-
model_name_or_path=metadata["model_name_or_path"],
148-
device=device or metadata["device"],
149-
batch_size=batch_size or metadata["batch_size"],
150-
max_length=metadata["max_length"],
151-
use_cache=use_cache or metadata["use_cache"],
135+
EmbedderConfig(
136+
model_name=metadata["model_name_or_path"],
137+
device=metadata["device"],
138+
batch_size=metadata["batch_size"],
139+
max_length=metadata["max_length"],
140+
use_cache=metadata["use_cache"],
141+
)
152142
)
153143

154144
def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
@@ -189,4 +179,4 @@ def embed(self, utterances: list[str]) -> npt.NDArray[np.float32]:
189179
embeddings_path.parent.mkdir(parents=True, exist_ok=True)
190180
np.save(embeddings_path, embeddings)
191181

192-
return embeddings
182+
return embeddings # type: ignore[return-value]

autointent/_pipeline/_pipeline.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import yaml
1010

1111
from autointent import Context, Dataset
12-
from autointent.configs import CrossEncoderConfig, EmbedderConfig, InferenceNodeConfig, LoggingConfig, VectorIndexConfig
12+
from autointent.configs import InferenceNodeConfig, LoggingConfig, VectorIndexConfig
1313
from autointent.custom_types import ListOfGenericLabels, NodeType
1414
from autointent.metrics import PREDICTION_METRICS_MULTILABEL
1515
from autointent.nodes import InferenceNode, NodeOptimizer
@@ -43,13 +43,11 @@ def __init__(
4343
if isinstance(nodes[0], NodeOptimizer):
4444
self.logging_config = LoggingConfig(dump_dir=None)
4545
self.vector_index_config = VectorIndexConfig()
46-
self.embedder_config = EmbedderConfig()
47-
self.cross_encoder_config = CrossEncoderConfig()
4846
elif not isinstance(nodes[0], InferenceNode):
4947
msg = "Pipeline should be initialized with list of NodeOptimizers or InferenceNodes"
5048
raise TypeError(msg)
5149

52-
def set_config(self, config: LoggingConfig | VectorIndexConfig | EmbedderConfig | CrossEncoderConfig) -> None:
50+
def set_config(self, config: LoggingConfig | VectorIndexConfig) -> None:
5351
"""
5452
Set configuration for the optimizer.
5553
@@ -59,10 +57,6 @@ def set_config(self, config: LoggingConfig | VectorIndexConfig | EmbedderConfig
5957
self.logging_config = config
6058
elif isinstance(config, VectorIndexConfig):
6159
self.vector_index_config = config
62-
elif isinstance(config, EmbedderConfig):
63-
self.embedder_config = config
64-
elif isinstance(config, CrossEncoderConfig):
65-
self.cross_encoder_config = config
6660
else:
6761
msg = "unknown config type"
6862
raise TypeError(msg)
@@ -136,8 +130,6 @@ def fit(self, dataset: Dataset) -> Context:
136130
context = Context()
137131
context.set_dataset(dataset)
138132
context.configure_logging(self.logging_config)
139-
context.configure_vector_index(self.vector_index_config, self.embedder_config)
140-
context.configure_cross_encoder(self.cross_encoder_config)
141133

142134
self._fit(context)
143135

autointent/_ranker.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,15 @@
1919
from torch import nn
2020

2121
from autointent.custom_types import ListOfLabels
22+
from autointent.schemas._schemas import CrossEncoderConfig
2223

2324
logger = logging.getLogger(__name__)
2425

2526

2627
class CrossEncoderMetadata(TypedDict):
2728
model_name: str
2829
train_classifier: bool
29-
device: str
30+
device: str | None
3031
max_length: int | None
3132
batch_size: int
3233

@@ -104,32 +105,27 @@ class Ranker:
104105

105106
def __init__(
106107
self,
107-
model_name: str,
108-
device: str = "cpu",
109-
train_classifier: bool = False,
110-
batch_size: int = 326,
111-
max_length: int | None = None,
108+
cross_encoder_config: CrossEncoderConfig,
112109
classifier_head: LogisticRegressionCV | None = None,
113110
) -> None:
114111
"""
115112
Initialize the Ranker.
116113
117-
:param model: The cross-encoder hugging face model name to use.
118-
:param device: Device to run operations on, e.g., "cpu" or "cuda".
119-
:param train_classifier: Whether to train a custom classifier, defaults to False.
120-
:param batch_size: Batch size for processing text pairs, defaults to 326.
114+
:param cross_encoder_config: Cross-encoder configuration providing the model name, device, max length, batch size and ``train_head`` flag.
121115
:param max_length (int, optional): Max length for input sequences for the cross encoder.
122116
:param classifier_head (LogisticRegressionCV, optional): Classifier (to be used in restore procedure mainly).
123117
"""
124-
self.model_name = model_name
125-
self.device = device
126-
self.cross_encoder = st.CrossEncoder(model_name, trust_remote_code=True, device=device, max_length=max_length) # type: ignore[arg-type]
118+
self.cross_encoder = st.CrossEncoder(
119+
cross_encoder_config.model_name,
120+
trust_remote_code=True,
121+
device=cross_encoder_config.device, # type: ignore[arg-type]
122+
max_length=cross_encoder_config.max_length, # type: ignore[arg-type]
123+
)
127124
self.train_classifier = False
128-
self.batch_size = batch_size
129-
self.max_length = max_length
130125
self._clf = classifier_head
126+
self.cross_encoder_config = cross_encoder_config
131127

132-
if classifier_head is not None or train_classifier:
128+
if classifier_head is not None or cross_encoder_config.train_head:
133129
self.train_classifier = True
134130
self._activations_list: list[npt.NDArray[Any]] = []
135131
self._hook_handler = self.cross_encoder.model.classifier.register_forward_hook(self._classifier_hook)
@@ -149,10 +145,14 @@ def _get_features_or_predictions(self, pairs: list[tuple[str, str]]) -> npt.NDAr
149145
:return: Numpy array of extracted features.
150146
"""
151147
if not self.train_classifier:
152-
return np.array(self.cross_encoder.predict(pairs, batch_size=self.batch_size, activation_fct=nn.Sigmoid()))
148+
return np.array(
149+
self.cross_encoder.predict(
150+
pairs, batch_size=self.cross_encoder_config.batch_size, activation_fct=nn.Sigmoid() # type: ignore[arg-type]
151+
)
152+
)
153153

154154
# put the data through, features will be taken in the hook
155-
self.cross_encoder.predict(pairs, batch_size=self.batch_size)
155+
self.cross_encoder.predict(pairs, batch_size=self.cross_encoder_config.batch_size) # type: ignore[arg-type]
156156

157157
res = np.concatenate(self._activations_list, axis=0)
158158
self._activations_list.clear()
@@ -222,8 +222,8 @@ def rank(
222222
Rank documents according to meaning closeness to the query.
223223
224224
:param query: The reference document.
225-
:query_docs: List of documents to rank
226-
:top_k: how many document to return
225+
:param query_docs: List of documents to rank
226+
:param top_k: how many documents to return
227227
:return: array of dictionaries of ranked items.
228228
"""
229229
query_doc_pairs = [(query, doc) for doc in query_docs]
@@ -246,11 +246,11 @@ def save(self, path: str) -> None:
246246
dump_dir.mkdir(parents=True)
247247

248248
metadata = CrossEncoderMetadata(
249-
model_name=self.model_name,
249+
model_name=self.cross_encoder_config.model_name,
250250
train_classifier=self.train_classifier,
251-
device=self.device,
252-
max_length=self.max_length,
253-
batch_size=self.batch_size,
251+
device=self.cross_encoder_config.device,
252+
max_length=self.cross_encoder_config.max_length,
253+
batch_size=self.cross_encoder_config.batch_size,
254254
)
255255

256256
with (dump_dir / self.metadata_file_name).open("w") as file:
@@ -271,4 +271,13 @@ def load(cls, path: Path) -> "Ranker":
271271
with (path / cls.metadata_file_name).open() as file:
272272
metadata: CrossEncoderMetadata = json.load(file)
273273

274-
return cls(**metadata, classifier_head=clf)
274+
return cls(
275+
CrossEncoderConfig(
276+
model_name=metadata["model_name"],
277+
device=metadata["device"],
278+
max_length=metadata["max_length"],
279+
batch_size=metadata["batch_size"],
280+
train_head=metadata["train_classifier"],
281+
),
282+
classifier_head=clf,
283+
)

autointent/_vector_index.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
from autointent import Embedder
1818
from autointent.custom_types import ListOfLabels
19+
from autointent.schemas._schemas import EmbedderConfig
1920

2021

2122
class VectorIndexMetadata(TypedDict):
2223
embedder_model_name: str
23-
embedder_device: str
24+
embedder_device: str | None
2425
embedder_batch_size: int
2526
embedder_max_length: int | None
2627
embedder_use_cache: bool
@@ -42,31 +43,13 @@ class VectorIndex:
4243
_data_file = "data.json"
4344
_meta_data_file = "metadata.json"
4445

45-
def __init__(
46-
self,
47-
embedder_model_name: str,
48-
embedder_device: str,
49-
embedder_batch_size: int = 32,
50-
embedder_max_length: int | None = None,
51-
embedder_use_cache: bool = True,
52-
) -> None:
46+
def __init__(self, embedder_config: EmbedderConfig) -> None:
5347
"""
5448
Initialize the vector index.
5549
56-
:param embedder_model_name: Name of the embedding model to use.
57-
:param embedder_device: Device for running the embedding model (e.g., "cpu", "cuda").
58-
:param embedder_batch_size: Batch size for the embedder.
59-
:param embedder_max_length: Maximum sequence length for the embedder.
60-
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
61-
"""
62-
self.embedder = Embedder(
63-
model_name_or_path=embedder_model_name,
64-
batch_size=embedder_batch_size,
65-
device=embedder_device,
66-
max_length=embedder_max_length,
67-
use_cache=embedder_use_cache,
68-
)
69-
self.embedder_device = embedder_device
50+
:param embedder_config: Configuration used to construct the underlying embedder.
51+
"""
52+
self.embedder = Embedder(embedder_config)
7053

7154
self.labels: ListOfLabels = [] # (n_samples,) or (n_samples, n_classes)
7255
self.texts: list[str] = []
@@ -233,11 +216,13 @@ def load(
233216
metadata: VectorIndexMetadata = json.load(file)
234217

235218
instance = cls(
236-
embedder_model_name=metadata["embedder_model_name"],
237-
embedder_device=embedder_device or metadata["embedder_device"],
238-
embedder_batch_size=embedder_batch_size or metadata["embedder_batch_size"],
239-
embedder_max_length=metadata["embedder_max_length"],
240-
embedder_use_cache=embedder_use_cache or metadata["embedder_use_cache"],
219+
EmbedderConfig(
220+
model_name=metadata["embedder_model_name"],
221+
device=embedder_device or metadata["embedder_device"],
222+
batch_size=embedder_batch_size or metadata["embedder_batch_size"],
223+
max_length=metadata["embedder_max_length"],
224+
use_cache=embedder_use_cache or metadata["embedder_use_cache"],
225+
)
241226
)
242227

243228
with (dir_path / cls._data_file).open() as file:

autointent/configs/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,14 @@
22

33
from ._inference_node import InferenceNodeConfig
44
from ._optimization import (
5-
CrossEncoderConfig,
65
DataConfig,
7-
EmbedderConfig,
86
LoggingConfig,
97
TaskConfig,
108
VectorIndexConfig,
119
)
1210

1311
__all__ = [
14-
"CrossEncoderConfig",
1512
"DataConfig",
16-
"EmbedderConfig",
1713
"InferenceNodeConfig",
1814
"InferenceNodeConfig",
1915
"LoggingConfig",

autointent/configs/_optimization.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -55,44 +55,3 @@ class VectorIndexConfig(BaseModel):
5555

5656
save_db: bool = False
5757
"""Whether to save the vector index database or not"""
58-
59-
60-
class TransformerConfig(BaseModel):
61-
"""
62-
Base class for configuration for the transformer.
63-
64-
Transformer is used under the hood in :py:class:`autointent.Embedder` and :py:class:`autointent.Ranker`.
65-
"""
66-
67-
batch_size: int = 32
68-
"""Batch size for the embedder"""
69-
max_length: int | None = None
70-
"""Max length for the embedder. If None, the max length will be taken from model config"""
71-
device: str = "cpu"
72-
"""Device to use for the vector index. Can be 'cpu', 'cuda', 'cuda:0', 'mps', etc."""
73-
74-
75-
class EmbedderConfig(TransformerConfig):
76-
"""
77-
Configuration for the embedder.
78-
79-
The embedder is used to embed the data before training the model. These parameters
80-
will be applied to the embedder used in the optimization process in vector db.
81-
Only one model can be used globally.
82-
"""
83-
84-
use_cache: bool = True
85-
"""Whether to cache embeddings for reuse, improving performance in repeated operations."""
86-
87-
88-
class CrossEncoderConfig(TransformerConfig):
89-
"""
90-
Configuration for the embedder.
91-
92-
The embedder is used to embed the data before training the model. These parameters
93-
will be applied to the embedder used in the optimization process in vector db.
94-
Only one model can be used globally.
95-
"""
96-
97-
train_head: bool = False
98-
"""Whether to train the ranking head of a cross encoder."""

0 commit comments

Comments
 (0)