
Commit 6aa7abc

integrate embeddings fine-tuning into Embedding modules

1 parent cb9b2ea · commit 6aa7abc

7 files changed (+87 -31 lines)

autointent/_wrappers/embedder.py

Lines changed: 36 additions & 7 deletions

@@ -10,6 +10,7 @@
 import tempfile
 from functools import lru_cache
 from pathlib import Path
+from uuid import uuid4
 
 import huggingface_hub
 import numpy as np
@@ -26,6 +27,7 @@
 
 from autointent._hash import Hasher
 from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig, TaskTypeEnum
+from autointent.custom_types import ListOfLabels
 
 logger = logging.getLogger(__name__)
 
@@ -72,7 +74,9 @@ class Embedder:
     """
 
     _metadata_dict_name: str = "metadata.json"
+    _weights_dir_name: str = "sentence_transformer"
     _dump_dir: Path | None = None
+    _trained: bool = False
 
     def __init__(self, embedder_config: EmbedderConfig) -> None:
         """Initialize the Embedder.
@@ -89,7 +93,7 @@ def _get_hash(self) -> int:
             The hash value of the Embedder.
         """
         hasher = Hasher()
-        if self.config.freeze:
+        if not Path(self.config.model_name).exists():
             commit_hash = _get_latest_commit_hash(self.config.model_name)
             hasher.update(commit_hash)
         else:
@@ -113,8 +117,22 @@ def _load_model(self) -> SentenceTransformer:
         res = self.embedding_model
         return res
 
-    def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTuningConfig) -> None:
+    def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
         """Train the embedding model."""
+        if len(utterances) != len(labels):
+            msg = f"Utterances and labels lists lengths mismatch: {len(utterances)=} != {len(labels)=}"
+            raise ValueError(msg)
+
+        if len(labels) == 0:
+            msg = "Empty data"
+            raise ValueError(msg)
+
+        # TODO support multi-label data
+        if isinstance(labels[0], list):
+            msg = "Multi-label data is not supported for embeddings fine-tuning for now"
+            logger.warning(msg)
+            return
+
         self._load_model()
         if config.early_stopping:
             x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=0.1, random_state=42)
@@ -131,8 +149,7 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
             output_dir=tmp_dir,
             num_train_epochs=config.epoch_num,
             per_device_train_batch_size=config.batch_size,
-            per_device_eval_batch_size=8,
-            eval_steps=1,
+            per_device_eval_batch_size=config.batch_size,
             learning_rate=config.learning_rate,
             warmup_ratio=config.warmup_ratio,
             fp16=config.fp16,
@@ -143,9 +160,9 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
             eval_strategy="epoch",
             greater_is_better=False,
         )
-        callback: list[TrainerCallback] = []
+        callbacks: list[TrainerCallback] = []
         if config.early_stopping:
-            callback.append(
+            callbacks.append(
                 EarlyStoppingCallback(
                     early_stopping_patience=config.early_stopping,
                     early_stopping_threshold=config.early_stopping_threshold,
@@ -157,11 +174,18 @@ def train(self, utterances: list[str], labels: list[int], config: EmbedderFineTu
             train_dataset=tr_ds,
             eval_dataset=val_ds,
             loss=loss,
-            callbacks=callback,
+            callbacks=callbacks,
         )
 
         trainer.train()
 
+        # use temporary path for re-usage
+        model_path = str(Path(tempfile.mkdtemp("autointent_embedders")) / str(uuid4()))
+        self.embedding_model.save(model_path)
+        self.config.model_name = model_path
+
+        self._trained = True
+
     def clear_ram(self) -> None:
         """Move the embedding model to CPU and delete it from memory."""
         if hasattr(self, "embedding_model"):
@@ -182,6 +206,11 @@ def dump(self, path: Path) -> None:
         Args:
             path: Path to the directory where the model will be saved.
         """
+        if self._trained:
+            model_path = str((path / self._weights_dir_name).resolve())
+            self.embedding_model.save(model_path, create_model_card=False)
+            self.config.model_name = model_path
+
         self._dump_dir = path
         path.mkdir(parents=True, exist_ok=True)
         with (path / self._metadata_dict_name).open("w") as file:
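
Net effect of these changes: `train` fine-tunes the loaded sentence-transformers model, saves the tuned weights to a temporary directory, and repoints `config.model_name` at that path, so hashing (`_get_hash` now branches on whether `model_name` is a local path), caching, and `dump` all pick up the tuned model. A minimal usage sketch follows; the checkpoint name and the `EmbedderFineTuningConfig` values are illustrative assumptions, not defaults established by this commit.

```python
# Hedged sketch of driving the new fine-tuning hook directly.
# The checkpoint and the config field values below are assumptions.
from autointent import Embedder
from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig

embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))

# Single-label data only: multi-label input triggers the warning above and skips training.
embedder.train(
    utterances=["hi there", "hello", "bye", "see you"],
    labels=[0, 0, 1, 1],
    config=EmbedderFineTuningConfig(epoch_num=1, batch_size=8),
)

# config.model_name now points at a temp dir with the tuned weights, so a later
# dump(path) re-saves them under the module's own "sentence_transformer" subdirectory.
```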

autointent/_wrappers/vector_index/vector_index.py

Lines changed: 2 additions & 2 deletions

@@ -38,15 +38,15 @@ class VectorIndex:
     embedder: Embedder
     index: BaseIndexBackend
 
-    def __init__(self, embedder_config: EmbedderConfig, config: VectorIndexConfig) -> None:
+    def __init__(self, embedder_config: EmbedderConfig | Embedder, config: VectorIndexConfig) -> None:
         """Initialize the VectorIndex with an embedding model.
 
         Args:
             embedder_config: Configuration for the embedding model.
             config: settings for vector index.
             backend: vector index backend to use.
         """
-        self.embedder = Embedder(embedder_config)
+        self.embedder = embedder_config if isinstance(embedder_config, Embedder) else Embedder(embedder_config)
         self.config = config
 
     def _init_index(self, vector_size: int) -> BaseIndexBackend:
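
The constructor now also accepts a ready `Embedder`, so a caller that has already fine-tuned one (see `RetrievalAimedEmbedding.fit` below) can hand over the instance instead of a config. A short sketch of both paths; note that `EmbedderConfig()` with no arguments is valid per the `from_search_config(None)` branch in this commit.

```python
# Sketch: both accepted forms of VectorIndex's first argument.
from autointent import Embedder, VectorIndex
from autointent.configs import EmbedderConfig, get_default_vector_index_config

index_config = get_default_vector_index_config()

# Config path: VectorIndex builds the Embedder internally, as before.
index_a = VectorIndex(EmbedderConfig(), config=index_config)

# Instance path: an existing (possibly fine-tuned) Embedder is reused as-is.
embedder = Embedder(EmbedderConfig())
index_b = VectorIndex(embedder, config=index_config)
```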

autointent/configs/_transformers.py

Lines changed: 16 additions & 5 deletions

@@ -2,7 +2,7 @@
 from typing import Any, Literal
 
 from pydantic import BaseModel, ConfigDict, Field, PositiveInt
-from typing_extensions import Self
+from typing_extensions import Self, assert_never
 
 from autointent.custom_types import FloatFromZeroToOne
 from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
@@ -26,6 +26,16 @@ class EmbedderFineTuningConfig(BaseModel):
     fp16: bool = Field(default=False)
     bf16: bool = Field(default=False)
 
+    @classmethod
+    def from_search_config(cls, values: dict[str, Any] | BaseModel | None) -> Self | None:
+        if isinstance(values, BaseModel):
+            return cls(**values.model_dump())
+        if isinstance(values, dict):
+            return cls(**values)
+        if values is None:
+            return None
+        assert_never(values)
+
 
 class HFModelConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
@@ -54,7 +64,7 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) ->
         if values is None:
             return cls()
         if isinstance(values, BaseModel):
-            return values  # type: ignore[return-value]
+            return cls(**values.model_dump())
         if isinstance(values, str):
             return cls(model_name=values)
         return cls(**values)
@@ -85,7 +95,6 @@ class EmbedderConfig(HFModelConfig):
         "cosine", description="Name of the similarity function to use."
     )
     use_cache: bool = Field(True, description="Whether to use embeddings caching.")
-    freeze: bool = Field(True, description="Whether to freeze the model parameters.")
 
     def get_prompt_config(self) -> dict[str, str] | None:
         """Get the prompt config for the given prompt type.
@@ -174,5 +183,7 @@ def from_search_config(cls, values: dict[str, Any] | BaseModel | None) -> Self:
         if values is None:
             return cls()
         if isinstance(values, BaseModel):
-            return values  # type: ignore[return-value]
-        return cls(**values)
+            return cls(**values.model_dump())
+        if isinstance(values, dict):
+            return cls(**values)
+        assert_never(values)
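
The new `EmbedderFineTuningConfig.from_search_config` mirrors the existing helpers but maps `None` to `None`, so 'no config given' means 'no fine-tuning'; the `BaseModel` branches now also return a freshly validated copy instead of echoing the input object. A sketch of the three branches (the `epoch_num` value is illustrative):

```python
# Sketch of the three normalization branches.
from autointent.configs import EmbedderFineTuningConfig

cfg = EmbedderFineTuningConfig.from_search_config({"epoch_num": 2})  # dict -> validated config

copy_cfg = EmbedderFineTuningConfig.from_search_config(cfg)  # BaseModel -> fresh re-validated copy
assert copy_cfg is not cfg

assert EmbedderFineTuningConfig.from_search_config(None) is None  # None -> fine-tuning disabled

# Any other input type is unreachable per the annotation; assert_never makes
# that explicit to static type checkers.
```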

autointent/context/optimization_info/_optimization_info.py

Lines changed: 2 additions & 4 deletions

@@ -20,7 +20,7 @@
 from autointent.configs import EmbedderConfig, InferenceNodeConfig
 from autointent.custom_types import NodeType
 
-from ._data_models import Artifact, Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials
+from ._data_models import Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials
 
 if TYPE_CHECKING:
     from autointent.modules.base import BaseModule
@@ -95,7 +95,6 @@ def log_module_optimization(
         metric_value: float,
         metric_name: str,
         metrics: dict[str, float],
-        artifact: Artifact,
         module_dump_dir: str | None,
         module: "BaseModule",
     ) -> None:
@@ -108,7 +107,6 @@ def log_module_optimization(
            metric_value: Metric value achieved by the module.
            metric_name: Name of the evaluation metric.
            metrics: Dictionary of metric names and their values.
-           artifact: Artifact generated by the module.
            module_dump_dir: Directory where the module is dumped.
            module: The module instance, if available.
        """
@@ -117,7 +115,7 @@ def log_module_optimization(
         self.modules.add_module(node_type, module)
         if module_dump_dir is not None:
             module.dump(module_dump_dir)
-        self.artifacts.add_artifact(node_type, artifact)
+        self.artifacts.add_artifact(node_type, module.get_assets())
 
         if old_best_metric_value_idx is not None:
             prev_best_dump = self.trials.get_trials(node_type)[old_best_metric_value_idx].module_dump_dir

autointent/modules/embedding/_logreg.py

Lines changed: 12 additions & 6 deletions

@@ -10,7 +10,7 @@
 from sklearn.preprocessing import LabelEncoder
 
 from autointent import Context, Embedder
-from autointent.configs import EmbedderConfig, TaskTypeEnum
+from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig, TaskTypeEnum
 from autointent.context.optimization_info import EmbeddingArtifact
 from autointent.custom_types import ListOfLabels
 from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
@@ -26,6 +26,7 @@ class LogregAimedEmbedding(BaseEmbedding):
     Args:
         embedder_config: Config of the embedder used for creating embeddings
         cv: Number of folds used in LogisticRegressionCV
+        ft_config: settings for fine-tuning embeddings
 
     Examples:
     --------
@@ -52,9 +53,11 @@ def __init__(
         self,
         embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         cv: PositiveInt = 3,
+        ft_config: EmbedderFineTuningConfig | dict[str, Any] | None = None,
     ) -> None:
-        self.embedder_config = EmbedderConfig.from_search_config(embedder_config)
+        self._embedder = Embedder(EmbedderConfig.from_search_config(embedder_config))
         self.cv = cv
+        self.ft_config = EmbedderFineTuningConfig.from_search_config(ft_config)
 
         if self.cv < 0 or not isinstance(self.cv, int):
             msg = "`cv` argument of `LogregAimedEmbedding` must be a positive int"
@@ -65,6 +68,7 @@ def from_context(
         cls,
         context: Context,
         embedder_config: EmbedderConfig | str | None = None,
+        ft_config: EmbedderFineTuningConfig | dict[str, Any] | None = None,
         cv: PositiveInt = 3,
     ) -> "LogregAimedEmbedding":
         """Create a LogregAimedEmbedding instance using a Context object.
@@ -73,10 +77,12 @@ def from_context(
            context: Context containing configurations and utilities
            cv: Number of folds used in LogisticRegressionCV
            embedder_config: Config of the embedder to use
+           ft_config: settings for fine-tuning embeddings
        """
         return cls(
             cv=cv,
             embedder_config=embedder_config,
+            ft_config=ft_config,
         )
 
     def clear_cache(self) -> None:
@@ -93,9 +99,9 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
        """
         self._validate_task(labels)
 
-        self._embedder = Embedder(
-            self.embedder_config,
-        )
+        if self.ft_config is not None:
+            self._embedder.train(utterances=utterances, labels=labels, config=self.ft_config)
+
         embeddings = self._embedder.embed(utterances, TaskTypeEnum.classification)
 
         if self._multilabel:
@@ -153,7 +159,7 @@ def get_assets(self) -> EmbeddingArtifact:
        Returns:
            EmbeddingArtifact object containing embedder information
        """
-        return EmbeddingArtifact(config=self.embedder_config)
+        return EmbeddingArtifact(config=self._embedder.config)
 
     def predict(self, utterances: list[str]) -> NDArray[np.float64]:
         """Predict probabilities for input utterances.

autointent/modules/embedding/_retrieval.py

Lines changed: 19 additions & 6 deletions

@@ -4,8 +4,13 @@
 
 from pydantic import PositiveInt
 
-from autointent import Context, VectorIndex
-from autointent.configs import EmbedderConfig, VectorIndexConfig, get_default_vector_index_config
+from autointent import Context, Embedder, VectorIndex
+from autointent.configs import (
+    EmbedderConfig,
+    EmbedderFineTuningConfig,
+    VectorIndexConfig,
+    get_default_vector_index_config,
+)
 from autointent.context.optimization_info import EmbeddingArtifact
 from autointent.custom_types import ListOfLabels
 from autointent.metrics import RETRIEVAL_METRICS_MULTICLASS, RETRIEVAL_METRICS_MULTILABEL
@@ -21,6 +26,7 @@ class RetrievalAimedEmbedding(BaseEmbedding):
     Args:
         k: Number of nearest neighbors to retrieve
         embedder_config: Config of the embedder used for creating embeddings
+        ft_config: settings for fine-tuning embeddings
 
     Examples:
     --------
@@ -49,11 +55,12 @@ def __init__(
         embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
         vector_index_config: VectorIndexConfig | None = None,
         k: PositiveInt = 10,
+        ft_config: EmbedderFineTuningConfig | dict[str, Any] | None = None,
     ) -> None:
         self.k = k
-        embedder_config = EmbedderConfig.from_search_config(embedder_config)
-        self.embedder_config = embedder_config
+        self._embedder = Embedder(EmbedderConfig.from_search_config(embedder_config))
         self.vector_index_config = vector_index_config or get_default_vector_index_config()
+        self.ft_config = EmbedderFineTuningConfig.from_search_config(ft_config)
 
         if self.k < 0 or not isinstance(self.k, int):
             msg = "`k` argument of `RetrievalAimedEmbedding` must be a positive int"
@@ -65,18 +72,21 @@ def from_context(
         context: Context,
         embedder_config: EmbedderConfig | str | None = None,
         k: PositiveInt = 10,
+        ft_config: EmbedderFineTuningConfig | dict[str, Any] | None = None,
     ) -> "RetrievalAimedEmbedding":
         """Create an instance using a Context object.
 
        Args:
            context: The context containing configurations and utilities
            k: Number of nearest neighbors to retrieve
            embedder_config: Config of the embedder to use
+           ft_config: settings for fine-tuning embeddings
        """
         return cls(
             k=k,
             embedder_config=embedder_config,
             vector_index_config=context.vector_index_config,
+            ft_config=ft_config,
         )
 
     def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
@@ -88,7 +98,10 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
        """
         self._validate_task(labels)
 
-        self._vector_index = VectorIndex(self.embedder_config, config=self.vector_index_config)
+        if self.ft_config is not None:
+            self._embedder.train(utterances=utterances, labels=labels, config=self.ft_config)
+
+        self._vector_index = VectorIndex(self._embedder, config=self.vector_index_config)
         self._vector_index.add(utterances, labels)
 
     def score_ho(self, context: Context, metrics: list[str]) -> dict[str, float]:
@@ -134,7 +147,7 @@ def get_assets(self) -> EmbeddingArtifact:
        Returns:
            A EmbeddingArtifact object containing embedder information
        """
-        return EmbeddingArtifact(config=self.embedder_config)
+        return EmbeddingArtifact(config=self._embedder.config)
 
     def clear_cache(self) -> None:
         """Clear cached data in memory used by the vector index."""

autointent/nodes/_node_optimizer.py

Lines changed: 0 additions & 1 deletion

@@ -156,7 +156,6 @@ def objective(
             metric_value=target_metric,
             metric_name=self.target_metric,
             metrics=quality_metrics,
-            artifact=module.get_assets(),  # retriever name / scores / predictions
             module_dump_dir=self.get_module_dump_dir(context, module_name, self._counter),
             module=module,
         )
