-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy path_retrieval.py
More file actions
148 lines (116 loc) · 5.34 KB
/
_retrieval.py
File metadata and controls
148 lines (116 loc) · 5.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""RetrievalAimedEmbedding class for a proxy optimization of embedding."""
from typing import Any
from pydantic import PositiveInt
from autointent import Context, VectorIndex
from autointent.configs import EmbedderConfig
from autointent.context.optimization_info import EmbeddingArtifact
from autointent.custom_types import ListOfLabels
from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL
from autointent.modules.abc import BaseEmbedding
class RetrievalAimedEmbedding(BaseEmbedding):
    r"""
    Module for configuring embeddings optimized for retrieval tasks.

    The main purpose of this module is to be used at the embedding node for optimizing
    the embedding configuration, using its retrieval quality as a proxy metric.

    :ivar _vector_index: The vector index used for nearest neighbor retrieval.
    :ivar name: Name of the module, defaults to "retrieval".

    Examples
    --------
    .. testcode::

        from autointent.modules.embedding import RetrievalAimedEmbedding
        utterances = ["bye", "how are you?", "good morning"]
        labels = [0, 1, 1]
        retrieval = RetrievalAimedEmbedding(
            k=2,
            embedder_config="sergeyzh/rubert-tiny-turbo",
        )
        retrieval.fit(utterances, labels)

    """

    _vector_index: VectorIndex

    name = "retrieval"
    supports_multiclass = True
    supports_multilabel = True
    supports_oos = False

    def __init__(
        self,
        embedder_config: EmbedderConfig | str | dict[str, Any],
        k: PositiveInt = 10,
    ) -> None:
        """
        Initialize the RetrievalAimedEmbedding.

        :param embedder_config: Config of the embedder used for creating embeddings.
            It is normalized through ``EmbedderConfig.from_search_config``.
        :param k: Number of nearest neighbors to retrieve; must be a positive int.
        :raises ValueError: If ``k`` is not a positive int.
        """
        # Check the type first: comparing a non-int (e.g. "3") with 0 would raise
        # TypeError instead of the documented ValueError. ``k == 0`` is rejected
        # because the parameter is declared as PositiveInt.
        if not isinstance(k, int) or k <= 0:
            msg = "`k` argument of `RetrievalAimedEmbedding` must be a positive int"
            raise ValueError(msg)
        self.k = k
        self.embedder_config = EmbedderConfig.from_search_config(embedder_config)

    @classmethod
    def from_context(
        cls,
        context: Context,
        embedder_config: EmbedderConfig | str | dict[str, Any],
        k: PositiveInt = 10,
    ) -> "RetrievalAimedEmbedding":
        """
        Create an instance using a Context object.

        :param context: The context containing configurations and utilities
            (not used by this module).
        :param embedder_config: Config of the embedder to use.
        :param k: Number of nearest neighbors to retrieve.
        :return: Initialized RetrievalAimedEmbedding instance.
        """
        return cls(
            k=k,
            embedder_config=embedder_config,
        )

    def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
        """
        Fit the vector index using the provided utterances and labels.

        Any previously built index has its in-memory data cleared before a new one
        is created.

        :param utterances: List of text data to index.
        :param labels: List of corresponding labels for the utterances.
        """
        # Safe to call on a fresh instance: clear_cache is a no-op before the first fit.
        self.clear_cache()
        self._validate_task(labels)
        self._vector_index = VectorIndex(
            self.embedder_config,
        )
        self._vector_index.add(utterances, labels)

    @staticmethod
    def _select_metrics(context: Context, metrics: list[str]) -> dict[str, Any]:
        """Return the retrieval metric functions named in ``metrics`` for the context's task type."""
        metrics_dict = RETRIEVAL_METRICS_MULTILABEL if context.is_multilabel() else RETRIEVAL_METRICS_MULTICLASS
        wanted = set(metrics)
        return {metric_name: fn for metric_name, fn in metrics_dict.items() if metric_name in wanted}

    def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
        """
        Evaluate the embedding on a hold-out split.

        Fits on the training data from ``context`` and scores retrieval predictions
        on validation split 0.

        :param context: The context containing train/validation data and labels.
        :param metrics: Names of the retrieval metrics to compute; names not present in
            the metric table for the task type are ignored.
        :return: Computed metric values for the validation set, or error codes of metrics.
        """
        train_utterances, train_labels = self.get_train_data(context)
        self.fit(train_utterances, train_labels)

        val_utterances = context.data_handler.validation_utterances(0)
        val_labels = context.data_handler.validation_labels(0)
        predictions = self.predict(val_utterances)

        chosen_metrics = self._select_metrics(context, metrics)
        return self.score_metrics_ho((val_labels, predictions), chosen_metrics)

    def score_cv(self, context: Context, metrics: list[str]) -> dict[str, float]:
        """
        Evaluate the embedding over the splits yielded by ``context.data_handler.validation_iterator()``.

        :param context: The context providing the validation iterator and task type.
        :param metrics: Names of the retrieval metrics to compute; names not present in
            the metric table for the task type are ignored.
        :return: Metric values as returned (first element) by ``score_metrics_cv``.
        """
        chosen_metrics = self._select_metrics(context, metrics)
        # score_metrics_cv returns a pair; only the aggregated metrics are needed here.
        metrics_calculated, _ = self.score_metrics_cv(chosen_metrics, context.data_handler.validation_iterator())
        return metrics_calculated

    def get_assets(self) -> EmbeddingArtifact:
        """
        Get the retriever artifacts for this module.

        :return: An EmbeddingArtifact object containing embedder information.
        """
        return EmbeddingArtifact(config=self.embedder_config)

    def clear_cache(self) -> None:
        """Clear cached data in memory used by the vector index; does nothing if not fitted yet."""
        if hasattr(self, "_vector_index"):
            self._vector_index.clear_ram()

    def predict(self, utterances: list[str]) -> list[ListOfLabels]:
        """
        Retrieve the labels of the ``k`` nearest neighbors for each utterance.

        :param utterances: List of query utterances.
        :return: For each query utterance, the labels of its retrieved neighbors.
        """
        predictions, _, _ = self._vector_index.query(utterances, self.k)
        return predictions