-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy path_retrieval.py
More file actions
160 lines (130 loc) · 6 KB
/
_retrieval.py
File metadata and controls
160 lines (130 loc) · 6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""RetrievalAimedEmbedding class for a proxy optimization of embedding."""
from pydantic import PositiveInt
from autointent import Context, VectorIndex
from autointent.context.optimization_info import RetrieverArtifact
from autointent.custom_types import ListOfLabels
from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL
from autointent.modules.abc import EmbeddingModule
class RetrievalAimedEmbedding(EmbeddingModule):
    r"""
    Module for configuring embeddings optimized for retrieval tasks.

    The main purpose of this module is to be used at the embedding node for optimizing
    embedding configuration using its retrieval quality as a sort of proxy metric.

    :ivar _vector_index: The vector index used for nearest neighbor retrieval.
    :ivar name: Name of the module, defaults to "retrieval".

    Examples
    --------
    .. testcode::

        from autointent.modules.embedding import RetrievalAimedEmbedding
        utterances = ["bye", "how are you?", "good morning"]
        labels = [0, 1, 1]
        retrieval = RetrievalAimedEmbedding(
            k=2,
            embedder_name="sergeyzh/rubert-tiny-turbo",
        )
        retrieval.fit(utterances, labels)
    """

    _vector_index: VectorIndex
    name = "retrieval"
    supports_multiclass = True
    supports_multilabel = True
    supports_oos = False

    def __init__(
        self,
        k: PositiveInt,
        embedder_name: str,
        embedder_device: str = "cpu",
        embedder_batch_size: int = 32,
        embedder_max_length: int | None = None,
        embedder_use_cache: bool = True,
    ) -> None:
        """
        Initialize the RetrievalAimedEmbedding.

        :param k: Number of nearest neighbors to retrieve.
        :param embedder_name: Name of the embedder used for creating embeddings.
        :param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
        :param embedder_batch_size: Batch size for embedding generation.
        :param embedder_max_length: Maximum sequence length for embeddings. None if not set.
        :param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
        """
        self.k = k
        self.embedder_name = embedder_name
        self.embedder_device = embedder_device
        self.embedder_batch_size = embedder_batch_size
        self.embedder_max_length = embedder_max_length
        self.embedder_use_cache = embedder_use_cache

    @classmethod
    def from_context(
        cls,
        context: Context,
        k: PositiveInt,
        embedder_name: str,
    ) -> "RetrievalAimedEmbedding":
        """
        Create an instance using a Context object.

        :param context: The context containing configurations and utilities.
        :param k: Number of nearest neighbors to retrieve.
        :param embedder_name: Name of the embedder to use.
        :return: Initialized RetrievalAimedEmbedding instance.
        """
        return cls(
            k=k,
            embedder_name=embedder_name,
            embedder_device=context.get_device(),
            embedder_batch_size=context.get_batch_size(),
            embedder_max_length=context.get_max_length(),
            embedder_use_cache=context.get_use_cache(),
        )

    def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
        """
        Fit the vector index using the provided utterances and labels.

        Re-fitting an already-fitted instance first clears the previous index
        from memory so stale embeddings are not retained.

        :param utterances: List of text data to index.
        :param labels: List of corresponding labels for the utterances.
        """
        if hasattr(self, "_vector_index"):
            self.clear_cache()

        self._validate_task(labels)

        self._vector_index = VectorIndex(
            self.embedder_name,
            self.embedder_device,
            self.embedder_batch_size,
            self.embedder_max_length,
            self.embedder_use_cache,
        )
        self._vector_index.add(utterances, labels)

    def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
        """
        Evaluate the embedding model on a hold-out validation split.

        :param context: The context containing train/validation data and labels.
        :param metrics: Names of the retrieval metrics to compute; names not found
            in the metric registry for the current task type are silently skipped.
        :return: Computed metric values for the validation set.
        """
        train_utterances, train_labels = self.get_train_data(context)
        self.fit(train_utterances, train_labels)

        val_utterances = context.data_handler.validation_utterances(0)
        val_labels = context.data_handler.validation_labels(0)
        predictions = self.predict(val_utterances)

        # Pick the metric registry matching the task type (multilabel vs multiclass).
        metrics_dict = RETRIEVAL_METRICS_MULTILABEL if context.is_multilabel() else RETRIEVAL_METRICS_MULTICLASS
        chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
        return self.score_metrics_ho((val_labels, predictions), chosen_metrics)

    def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
        """
        Evaluate the embedding model with cross-validation.

        :param context: The context providing the cross-validation iterator.
        :param metrics: Names of the retrieval metrics to compute; names not found
            in the metric registry for the current task type are silently skipped.
        :return: Computed metric values averaged over the CV folds.
        """
        metrics_dict = RETRIEVAL_METRICS_MULTILABEL if context.is_multilabel() else RETRIEVAL_METRICS_MULTICLASS
        chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
        metrics_calculated, _ = self.score_metrics_cv(chosen_metrics, context.data_handler.validation_iterator())
        return metrics_calculated

    def get_assets(self) -> RetrieverArtifact:
        """
        Get the retriever artifacts for this module.

        :return: A RetrieverArtifact object containing embedder information.
        """
        return RetrieverArtifact(embedder_name=self.embedder_name)

    def clear_cache(self) -> None:
        """Clear cached data in memory used by the vector index."""
        self._vector_index.clear_ram()

    def predict(self, utterances: list[str]) -> list[ListOfLabels]:
        """
        Predict the nearest neighbors for a list of utterances.

        Must be called after :meth:`fit`; the index is built there.

        :param utterances: List of utterances for which nearest neighbors are to be retrieved.
        :return: List of labels for each retrieved utterance.
        """
        predictions, _, _ = self._vector_index.query(utterances, self.k)
        return predictions