Skip to content

Commit 36214e5

Browse files
Darinochkavoorhs
andauthored
Feat/logreg retrieval (#81)
* feat: added logregretrieval * fix: predict for logregretrieval * fix: predict for logregretrieval * fix: added kwargs * fixed: dump & load modules and added tests * fix: fixed dump for RetrieverEmbedding * fix: fixed docstring * fix: change vector_index to embedder * fix: change to the ScoringMetricFn * fix: multilabel and fix scorer metric * fix: load and dump * fix: lint * fix: lint * fix: mypy * fix: docs * fix: docs * feat: change predict in RetrievalEmbedding * feat: change predict in RetrievalEmbedding * feat: update logregembedding * feat: update docstring * fix: fixed retrieval test * fix: fixed retrieval and logreg test * fix: added cv to the docs example * fix: fixed score func * fix: added accuracy for scorer in logreg * fix: added predict_proba * test: update tests * feat: divide retrieval and logreg * fix: fixed setup_environment * fix: fixed import * fix: deleted dump and load * fix: rename classifier and label encoder * fix: fixed multilabel * feat: updated tests * feat: update multiclass.yaml * fix: added cv * fix: lint * fix: fixed metric in multilabel.yaml * fix: fixed _classifier * fix: fixed label encoder * fix: fixed scoring * fix: fixed split in score * fix: fixed split in score * fix: type * feat: updated predict() in logreg * feat: updated test * fix: fixed lint * fix: no-any return * make changes * fix * remove `k` completely * remove k from search space * fix another `k` issue * finally? --------- Co-authored-by: voorhs <[email protected]>
1 parent 1ff18cf commit 36214e5

File tree

11 files changed

+336
-61
lines changed

11 files changed

+336
-61
lines changed

autointent/modules/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
ThresholdDecision,
1111
TunableDecision,
1212
)
13-
from .embedding import RetrievalEmbedding
13+
from .embedding import LogregAimedEmbedding, RetrievalAimedEmbedding
1414
from .scoring import DescriptionScorer, DNNCScorer, KNNScorer, LinearScorer, MLKnnScorer, RerankScorer, SklearnScorer
1515

1616
T = TypeVar("T", bound=Module)
@@ -20,7 +20,9 @@ def _create_modules_dict(modules: list[type[T]]) -> dict[str, type[T]]:
2020
return {module.name: module for module in modules}
2121

2222

23-
RETRIEVAL_MODULES_MULTICLASS: dict[str, type[EmbeddingModule]] = _create_modules_dict([RetrievalEmbedding])
23+
RETRIEVAL_MODULES_MULTICLASS: dict[str, type[EmbeddingModule]] = _create_modules_dict(
24+
[RetrievalAimedEmbedding, LogregAimedEmbedding]
25+
)
2426

2527
RETRIEVAL_MODULES_MULTILABEL = RETRIEVAL_MODULES_MULTICLASS
2628

autointent/modules/abc/_embedding.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,3 @@
77

88
class EmbeddingModule(Module, ABC):
99
"""Base class for embedding modules."""
10-
11-
def __init__(self, k: int) -> None:
12-
"""
13-
Initialize embedding module.
14-
15-
:param k: number of closest neighbors to consider during inference
16-
"""
17-
self.k = k
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""These modules are used only for optimization as they use proxy metrics for choosing best embedding model."""
22

3-
from ._retrieval import RetrievalEmbedding
3+
from ._logreg import LogregAimedEmbedding
4+
from ._retrieval import RetrievalAimedEmbedding
45

5-
__all__ = ["RetrievalEmbedding"]
6+
__all__ = ["LogregAimedEmbedding", "RetrievalAimedEmbedding"]
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
"""LogregAimedEmbedding class for a proxy optimzation of embedding."""
2+
3+
from typing import Literal
4+
5+
import numpy as np
6+
from numpy.typing import NDArray
7+
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
8+
from sklearn.multioutput import MultiOutputClassifier
9+
from sklearn.preprocessing import LabelEncoder
10+
11+
from autointent import Context, Embedder
12+
from autointent.context.optimization_info import RetrieverArtifact
13+
from autointent.custom_types import ListOfLabels
14+
from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
15+
from autointent.modules.abc import EmbeddingModule
16+
17+
18+
class LogregAimedEmbedding(EmbeddingModule):
19+
r"""
20+
Module for configuring embeddings optimized for linear classification.
21+
22+
The main purpose of this module is to be used at embedding node for optimizing
23+
embedding configuration using its logreg classification quality as a sort of proxy metric.
24+
25+
:ivar classifier: The trained logistic regression model.
26+
:ivar label_encoder: Label encoder for converting labels to numerical format.
27+
:ivar name: Name of the module, defaults to "logreg".
28+
29+
Examples
30+
--------
31+
.. testcode::
32+
33+
from autointent.modules.embedding import LogregAimedEmbedding
34+
utterances = ["bye", "how are you?", "good morning"]
35+
labels = [0, 1, 1]
36+
retrieval = LogregAimedEmbedding(
37+
embedder_name="sergeyzh/rubert-tiny-turbo",
38+
cv=2
39+
)
40+
retrieval.fit(utterances, labels)
41+
"""
42+
43+
_classifier: LogisticRegressionCV | MultiOutputClassifier
44+
_label_encoder: LabelEncoder | None
45+
name = "logreg"
46+
supports_multiclass = True
47+
supports_multilabel = True
48+
supports_oos = False
49+
50+
def __init__(
51+
self,
52+
embedder_name: str,
53+
cv: int = 3,
54+
embedder_device: str = "cpu",
55+
embedder_batch_size: int = 32,
56+
embedder_max_length: int | None = None,
57+
embedder_use_cache: bool = True,
58+
) -> None:
59+
"""
60+
Initialize the LogregAimedEmbedding.
61+
62+
:param cv: the number of folds used in LogisticRegressionCV
63+
:param embedder_name: Name of the embedder used for creating embeddings.
64+
:param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
65+
:param batch_size: Batch size for embedding generation.
66+
:param max_length: Maximum sequence length for embeddings. None if not set.
67+
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
68+
"""
69+
self.embedder_name = embedder_name
70+
self.embedder_device = embedder_device
71+
self.embedder_batch_size = embedder_batch_size
72+
self.embedder_max_length = embedder_max_length
73+
self.embedder_use_cache = embedder_use_cache
74+
self.cv = cv
75+
76+
@classmethod
77+
def from_context(
78+
cls,
79+
context: Context,
80+
cv: int,
81+
embedder_name: str,
82+
) -> "LogregAimedEmbedding":
83+
"""
84+
Create a LogregAimedEmbedding instance using a Context object.
85+
86+
:param cv: the number of folds used in LogisticRegressionCV
87+
:param context: The context containing configurations and utilities.
88+
:param embedder_name: Name of the embedder to use.
89+
:return: Initialized LogregAimedEmbedding instance.
90+
"""
91+
return cls(
92+
cv=cv,
93+
embedder_name=embedder_name,
94+
embedder_device=context.get_device(),
95+
embedder_batch_size=context.get_batch_size(),
96+
embedder_max_length=context.get_max_length(),
97+
embedder_use_cache=context.get_use_cache(),
98+
)
99+
100+
def clear_cache(self) -> None:
101+
pass
102+
103+
def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
104+
"""
105+
Train the logistic regression model using the provided utterances and labels.
106+
107+
:param utterances: List of text data to index.
108+
:param labels: List of corresponding labels for the utterances.
109+
"""
110+
self._validate_task(labels)
111+
112+
self._embedder = Embedder(
113+
device=self.embedder_device,
114+
model_name_or_path=self.embedder_name,
115+
batch_size=self.embedder_batch_size,
116+
max_length=self.embedder_max_length,
117+
use_cache=self.embedder_use_cache,
118+
)
119+
embeddings = self._embedder.embed(utterances)
120+
121+
if self._multilabel:
122+
self._label_encoder = None
123+
base_clf = LogisticRegression()
124+
self._classifier = MultiOutputClassifier(base_clf)
125+
else:
126+
self._label_encoder = LabelEncoder()
127+
labels = self._label_encoder.fit_transform(labels)
128+
self._classifier = LogisticRegressionCV(cv=self.cv)
129+
130+
self._classifier.fit(embeddings, labels)
131+
132+
def score(
133+
self,
134+
context: Context,
135+
split: Literal["validation", "test"],
136+
) -> dict[str, float | str]:
137+
"""
138+
Evaluate the embedding model using a specified metric function.
139+
140+
:param context: The context containing test data and labels.
141+
:param split: Target split
142+
:return: Computed metrics value for the test set or error code of metrics
143+
"""
144+
if split == "validation":
145+
utterances = context.data_handler.validation_utterances(0)
146+
labels = context.data_handler.validation_labels(0)
147+
elif split == "test":
148+
utterances = context.data_handler.test_utterances()
149+
labels = context.data_handler.test_labels()
150+
else:
151+
message = f"Invalid split '{split}' provided. Expected one of 'validation', or 'test'."
152+
raise ValueError(message)
153+
154+
probas = self.predict(utterances)
155+
metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
156+
return self.score_metrics((labels, probas), metrics_dict)
157+
158+
def get_assets(self) -> RetrieverArtifact:
159+
"""
160+
Get the classifier artifacts for this module.
161+
162+
:return: A RetrieverArtifact object containing embedder information.
163+
"""
164+
return RetrieverArtifact(embedder_name=self.embedder_name)
165+
166+
def predict(self, utterances: list[str]) -> NDArray[np.float64]:
167+
embeddings = self._embedder.embed(utterances)
168+
probas = self._classifier.predict_proba(embeddings)
169+
170+
if self._multilabel:
171+
probas = np.stack(probas, axis=1)[..., 1]
172+
173+
return probas # type: ignore[no-any-return]

autointent/modules/embedding/_retrieval.py

Lines changed: 18 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
1-
"""RetrievalEmbedding class for managing and interacting with a vector database for retrieval tasks."""
1+
"""RetrievalAimedEmbedding class for a proxy optimization of embedding."""
22

3-
from pathlib import Path
43
from typing import Literal
54

6-
from autointent import VectorIndex
7-
from autointent.context import Context
5+
from autointent import Context, VectorIndex
86
from autointent.context.optimization_info import RetrieverArtifact
97
from autointent.custom_types import ListOfLabels
108
from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL
119
from autointent.modules.abc import EmbeddingModule
1210

1311

14-
class RetrievalEmbedding(EmbeddingModule):
12+
class RetrievalAimedEmbedding(EmbeddingModule):
1513
r"""
16-
Module for managing retrieval operations using a vector database.
14+
Module for configuring embeddings optimized for retrieval tasks.
1715
18-
RetrievalEmbedding provides methods for indexing, querying, and managing a vector database for tasks
19-
such as nearest neighbor retrieval.
16+
The main purpose of this module is to be used at embedding node for optimizing
17+
embedding configuration using its retrieval quality as a sort of proxy metric.
2018
2119
:ivar vector_index: The vector index used for nearest neighbor retrieval.
2220
:ivar name: Name of the module, defaults to "retrieval".
@@ -26,25 +24,22 @@ class RetrievalEmbedding(EmbeddingModule):
2624
2725
.. testcode::
2826
29-
from autointent.modules.embedding import RetrievalEmbedding
27+
from autointent.modules.embedding import RetrievalAimedEmbedding
3028
utterances = ["bye", "how are you?", "good morning"]
3129
labels = [0, 1, 1]
32-
retrieval = RetrievalEmbedding(
30+
retrieval = RetrievalAimedEmbedding(
3331
k=2,
3432
embedder_name="sergeyzh/rubert-tiny-turbo",
3533
)
3634
retrieval.fit(utterances, labels)
37-
predictions = retrieval.predict(["how is the weather today?"])
38-
print(predictions)
39-
40-
.. testoutput::
41-
42-
([[1, 1]], [[0.1525942087173462, 0.18616724014282227]], [['good morning', 'how are you?']])
4335
4436
"""
4537

4638
_vector_index: VectorIndex
4739
name = "retrieval"
40+
supports_multiclass = True
41+
supports_multilabel = True
42+
supports_oos = False
4843

4944
def __init__(
5045
self,
@@ -56,7 +51,7 @@ def __init__(
5651
embedder_use_cache: bool = True,
5752
) -> None:
5853
"""
59-
Initialize the RetrievalEmbedding.
54+
Initialize the RetrievalAimedEmbedding.
6055
6156
:param k: Number of nearest neighbors to retrieve.
6257
:param embedder_name: Name of the embedder used for creating embeddings.
@@ -65,28 +60,27 @@ def __init__(
6560
:param max_length: Maximum sequence length for embeddings. None if not set.
6661
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
6762
"""
63+
self.k = k
6864
self.embedder_name = embedder_name
6965
self.embedder_device = embedder_device
7066
self.embedder_batch_size = embedder_batch_size
7167
self.embedder_max_length = embedder_max_length
7268
self.embedder_use_cache = embedder_use_cache
7369

74-
super().__init__(k=k)
75-
7670
@classmethod
7771
def from_context(
7872
cls,
7973
context: Context,
8074
k: int,
8175
embedder_name: str,
82-
) -> "RetrievalEmbedding":
76+
) -> "RetrievalAimedEmbedding":
8377
"""
84-
Create a RetrievalEmbedding instance using a Context object.
78+
Create an instance using a Context object.
8579
8680
:param context: The context containing configurations and utilities.
8781
:param k: Number of nearest neighbors to retrieve.
8882
:param embedder_name: Name of the embedder to use.
89-
:return: Initialized RetrievalEmbedding instance.
83+
:return: Initialized RetrievalAimedEmbedding instance.
9084
"""
9185
return cls(
9286
k=k,
@@ -104,6 +98,8 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
10498
:param utterances: List of text data to index.
10599
:param labels: List of corresponding labels for the utterances.
106100
"""
101+
self._validate_task(labels)
102+
107103
self._vector_index = VectorIndex(
108104
self.embedder_name,
109105
self.embedder_device,
@@ -151,22 +147,6 @@ def clear_cache(self) -> None:
151147
"""Clear cached data in memory used by the vector index."""
152148
self._vector_index.clear_ram()
153149

154-
def dump(self, path: str) -> None:
155-
"""
156-
Save the module's metadata and vector index to a specified directory.
157-
158-
:param path: Path to the directory where assets will be dumped.
159-
"""
160-
self._vector_index.dump(Path(path))
161-
162-
def load(self, path: str) -> None:
163-
"""
164-
Load the module's metadata and vector index from a specified directory.
165-
166-
:param path: Path to the directory containing the dumped assets.
167-
"""
168-
self._vector_index = VectorIndex.load(Path(path))
169-
170150
def predict(self, utterances: list[str]) -> tuple[list[ListOfLabels], list[list[float]], list[list[str]]]:
171151
"""
172152
Predict the nearest neighbors for a list of utterances.

tests/assets/configs/multilabel.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
- node_type: embedding
2-
metric: retrieval_hit_rate_intersecting
2+
metric: scoring_accuracy
33
search_space:
4-
- module_name: retrieval
5-
k: [10]
4+
- module_name: logreg
5+
cv: [2]
66
embedder_name:
77
- sentence-transformers/all-MiniLM-L6-v2
88
- avsolatorio/GIST-small-Embedding-v0
@@ -33,4 +33,4 @@
3333
- module_name: threshold
3434
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
3535
- module_name: tunable
36-
- module_name: adaptive
36+
- module_name: adaptive

0 commit comments

Comments
 (0)