
Commit ed82eae

Authored by k0lenk4, with co-authors voorhs and github-actions[bot]
feat/train-embeddings (#246)
* add train * fixed env * deleted kwargs and local savings, added config * added test for train method * add EmbedderFineTuningConfig to __init__ * correct __init__ in config, remov pytest in test file * correct some syntax isues * move batch_size to EmbedderFineTuningConfig * add __init__.py to /test/embedder * Remove whitespace from blank line * correct errors * the number of epochs and train objects have been increased * made lint * add early stopping * remake train args * make a list of callbacks * inline type annotation of variable "callback" * change save_strategy to "epoch" * default value of fp16 changed to False * integrate embeddings fine-tuning into Embedding modules * Update optimizer_config.schema.json * clean up `freeze` throughout tests and tutorials * add comprehensive tests * embedder_model -> _model * fix early stopping * fix tests * clear ram bug fix * try to fix windows cleanup issue --------- Co-authored-by: voorhs <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent e5256c5 commit ed82eae

File tree

24 files changed: +835 / -55 lines


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -182,3 +182,5 @@ vector_db*
 *.db
 *.sqlite
 /wandb
+model_output/
+my.py

docs/optimizer_config.schema.json

Lines changed: 1 addition & 8 deletions
@@ -266,12 +266,6 @@
         "description": "Whether to use embeddings caching.",
         "title": "Use Cache",
         "type": "boolean"
-      },
-      "freeze": {
-        "default": true,
-        "description": "Whether to freeze the model parameters.",
-        "title": "Freeze",
-        "type": "boolean"
       }
     },
     "title": "EmbedderConfig",
@@ -578,8 +572,7 @@
       "query_prompt": null,
       "passage_prompt": null,
       "similarity_fn_name": "cosine",
-      "use_cache": true,
-      "freeze": true
+      "use_cache": true
     }
   },
   "cross_encoder_config": {

src/autointent/_wrappers/embedder.py

Lines changed: 100 additions & 14 deletions
@@ -7,19 +7,27 @@
 import json
 import logging
 import shutil
+import tempfile
 from functools import lru_cache
 from pathlib import Path
+from uuid import uuid4
 
 import huggingface_hub
 import numpy as np
 import numpy.typing as npt
 import torch
 from appdirs import user_cache_dir
-from sentence_transformers import SentenceTransformer
+from datasets import Dataset
+from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
+from sentence_transformers.losses import BatchAllTripletLoss
 from sentence_transformers.similarity_functions import SimilarityFunction
+from sentence_transformers.training_args import BatchSamplers
+from sklearn.model_selection import train_test_split
+from transformers import EarlyStoppingCallback, TrainerCallback
 
 from autointent._hash import Hasher
-from autointent.configs import EmbedderConfig, TaskTypeEnum
+from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig, TaskTypeEnum
+from autointent.custom_types import ListOfLabels
 
 logger = logging.getLogger(__name__)
 
@@ -66,15 +74,18 @@ class Embedder:
     """
 
     _metadata_dict_name: str = "metadata.json"
+    _weights_dir_name: str = "sentence_transformer"
     _dump_dir: Path | None = None
+    _trained: bool = False
+    _model: SentenceTransformer
 
     def __init__(self, embedder_config: EmbedderConfig) -> None:
        """Initialize the Embedder.
 
        Args:
            embedder_config: Config of embedder.
        """
-        self.config = embedder_config
+        self.config = embedder_config.model_copy(deep=True)
 
     def _get_hash(self) -> int:
        """Compute a hash value for the Embedder.
@@ -83,19 +94,19 @@ def _get_hash(self) -> int:
            The hash value of the Embedder.
        """
        hasher = Hasher()
-        if self.config.freeze:
+        if not Path(self.config.model_name).exists():
            commit_hash = _get_latest_commit_hash(self.config.model_name)
            hasher.update(commit_hash)
        else:
-            self.embedding_model = self._load_model()
-            for parameter in self.embedding_model.parameters():
+            self._model = self._load_model()
+            for parameter in self._model.parameters():
                hasher.update(parameter.detach().cpu().numpy())
        hasher.update(self.config.tokenizer_config.max_length)
        return hasher.intdigest()
 
     def _load_model(self) -> SentenceTransformer:
        """Load sentence transformers model to device."""
-        if not hasattr(self, "embedding_model"):
+        if not hasattr(self, "_model"):
            res = SentenceTransformer(
                self.config.model_name,
                device=self.config.device,
@@ -104,15 +115,80 @@ def _load_model(self) -> SentenceTransformer:
                trust_remote_code=self.config.trust_remote_code,
            )
        else:
-            res = self.embedding_model
+            res = self._model
        return res
 
+    def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
+        """Train the embedding model."""
+        if len(utterances) != len(labels):
+            msg = f"Utterances and labels lists lengths mismatch: {len(utterances)=} != {len(labels)=}"
+            raise ValueError(msg)
+
+        if len(labels) == 0:
+            msg = "Empty data"
+            raise ValueError(msg)
+
+        # TODO support multi-label data
+        if isinstance(labels[0], list):
+            msg = "Multi-label data is not supported for embeddings fine-tuning for now"
+            logger.warning(msg)
+            return
+
+        self._model = self._load_model()
+
+        x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=config.val_fraction)
+        tr_ds = Dataset.from_dict({"text": x_train, "label": y_train})
+        val_ds = Dataset.from_dict({"text": x_val, "label": y_val})
+
+        loss = BatchAllTripletLoss(model=self._model, margin=config.margin)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            args = SentenceTransformerTrainingArguments(
+                save_strategy="epoch",
+                save_total_limit=1,
+                output_dir=tmp_dir,
+                num_train_epochs=config.epoch_num,
+                per_device_train_batch_size=config.batch_size,
+                per_device_eval_batch_size=config.batch_size,
+                learning_rate=config.learning_rate,
+                warmup_ratio=config.warmup_ratio,
+                fp16=config.fp16,
+                bf16=config.bf16,
+                batch_sampler=BatchSamplers.NO_DUPLICATES,
+                metric_for_best_model="eval_loss",
+                load_best_model_at_end=True,
+                eval_strategy="epoch",
+                greater_is_better=False,
+            )
+            callbacks: list[TrainerCallback] = [
+                EarlyStoppingCallback(
+                    early_stopping_patience=config.early_stopping_patience,
+                    early_stopping_threshold=config.early_stopping_threshold,
+                )
+            ]
+            trainer = SentenceTransformerTrainer(
+                model=self._model,
+                args=args,
+                train_dataset=tr_ds,
+                eval_dataset=val_ds,
+                loss=loss,
+                callbacks=callbacks,
+            )
+
+            trainer.train()
+
+        # use temporary path for re-usage
+        model_path = str(Path(tempfile.mkdtemp("autointent_embedders")) / str(uuid4()))
+        self._model.save(model_path)
+        self.config.model_name = model_path
+
+        self._trained = True
+
     def clear_ram(self) -> None:
        """Move the embedding model to CPU and delete it from memory."""
-        if hasattr(self, "embedding_model"):
+        if hasattr(self, "_model"):
            logger.debug("Clearing embedder %s from memory", self.config.model_name)
-            self.embedding_model.cpu()
-            del self.embedding_model
+            self._model.cpu()
+            del self._model
            torch.cuda.empty_cache()
 
     def delete(self) -> None:
@@ -127,6 +203,11 @@ def dump(self, path: Path) -> None:
        Args:
            path: Path to the directory where the model will be saved.
        """
+        if self._trained:
+            model_path = str((path / self._weights_dir_name).resolve())
+            self._model.save(model_path, create_model_card=False)
+            self.config.model_name = model_path
+
        self._dump_dir = path
        path.mkdir(parents=True, exist_ok=True)
        with (path / self._metadata_dict_name).open("w") as file:
@@ -164,6 +245,11 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
        Returns:
            A numpy array of embeddings.
        """
+        if len(utterances) == 0:
+            msg = "Empty input"
+            logger.error(msg)
+            raise ValueError(msg)
+
        prompt = self.config.get_prompt(task_type)
 
        if self.config.use_cache:
@@ -179,7 +265,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
                logger.debug("loading embeddings from %s", str(embeddings_path))
                return np.load(embeddings_path)  # type: ignore[no-any-return]
 
-        self.embedding_model = self._load_model()
+        self._model = self._load_model()
 
        logger.debug(
            "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",
@@ -191,9 +277,9 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
        )
 
        if self.config.tokenizer_config.max_length is not None:
-            self.embedding_model.max_seq_length = self.config.tokenizer_config.max_length
+            self._model.max_seq_length = self.config.tokenizer_config.max_length
 
-        embeddings = self.embedding_model.encode(
+        embeddings = self._model.encode(
            utterances,
            convert_to_numpy=True,
            batch_size=self.config.batch_size,
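For orientation, here is a minimal usage sketch (not part of the commit) of the new Embedder.train API introduced in this diff. The import path mirrors the file location src/autointent/_wrappers/embedder.py, and the model name, example utterances, and hyperparameter values are illustrative placeholders.

from autointent._wrappers.embedder import Embedder
from autointent.configs import EmbedderConfig, EmbedderFineTuningConfig

# Build an embedder from a config; the model name is a placeholder.
embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))

# Single-label data; multi-label inputs are skipped with a warning per the diff.
utterances = [
    "book a table for two", "reserve a table for tonight", "i need a dinner reservation",
    "cancel my reservation", "please cancel my booking", "drop my table reservation",
]
labels = [0, 0, 0, 1, 1, 1]

# epoch_num and batch_size are the only required fields of EmbedderFineTuningConfig.
fine_tuning = EmbedderFineTuningConfig(epoch_num=3, batch_size=16)
embedder.train(utterances, labels, fine_tuning)

# After training, config.model_name points at the temporary fine-tuned weights,
# so subsequent embed() calls use the fine-tuned model.
vectors = embedder.embed(utterances)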

src/autointent/_wrappers/vector_index/vector_index.py

Lines changed: 2 additions & 2 deletions
@@ -38,15 +38,15 @@ class VectorIndex:
     embedder: Embedder
     index: BaseIndexBackend
 
-    def __init__(self, embedder_config: EmbedderConfig, config: VectorIndexConfig) -> None:
+    def __init__(self, embedder_config: EmbedderConfig | Embedder, config: VectorIndexConfig) -> None:
        """Initialize the VectorIndex with an embedding model.
 
        Args:
            embedder_config: Configuration for the embedding model.
            config: settings for vector index.
            backend: vector index backend to use.
        """
-        self.embedder = Embedder(embedder_config)
+        self.embedder = embedder_config if isinstance(embedder_config, Embedder) else Embedder(embedder_config)
        self.config = config
 
     def _init_index(self, vector_size: int) -> BaseIndexBackend:
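A hedged sketch of what the relaxed constructor enables: passing an already-built (possibly fine-tuned) Embedder instance so the same model object is reused instead of re-instantiated from config. The import paths and the assumption that VectorIndexConfig is importable from autointent.configs and constructible with defaults are not confirmed by this diff.

from autointent._wrappers.embedder import Embedder
from autointent._wrappers.vector_index.vector_index import VectorIndex
from autointent.configs import EmbedderConfig, VectorIndexConfig  # VectorIndexConfig location assumed

embedder = Embedder(EmbedderConfig(model_name="sentence-transformers/all-MiniLM-L6-v2"))
# embedder.train(...) could run here so the index is built on fine-tuned weights.

# The Embedder instance is used as-is; a plain EmbedderConfig would still work too.
index = VectorIndex(embedder, config=VectorIndexConfig())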

src/autointent/configs/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@
     CrossEncoderConfig,
     EarlyStoppingConfig,
     EmbedderConfig,
+    EmbedderFineTuningConfig,
     HFModelConfig,
     TaskTypeEnum,
     TokenizerConfig,
@@ -18,6 +19,7 @@
     "DataConfig",
     "EarlyStoppingConfig",
     "EmbedderConfig",
+    "EmbedderFineTuningConfig",
     "FaissConfig",
     "HFModelConfig",
     "HPOConfig",

src/autointent/configs/_transformers.py

Lines changed: 29 additions & 5 deletions
@@ -2,7 +2,7 @@
 from typing import Any, Literal
 
 from pydantic import BaseModel, ConfigDict, Field, PositiveInt
-from typing_extensions import Self
+from typing_extensions import Self, assert_never
 
 from autointent.custom_types import FloatFromZeroToOne
 from autointent.metrics import SCORING_METRICS_MULTICLASS, SCORING_METRICS_MULTILABEL
@@ -15,6 +15,29 @@ class TokenizerConfig(BaseModel):
     max_length: PositiveInt | None = Field(None, description="Maximum length of input sequences.")
 
 
+class EmbedderFineTuningConfig(BaseModel):
+    epoch_num: int
+    batch_size: int
+    margin: float = Field(default=0.5)
+    learning_rate: float = Field(default=2e-5)
+    warmup_ratio: float = Field(default=0.1)
+    early_stopping_patience: int = Field(default=1)
+    early_stopping_threshold: float = Field(default=0.0)
+    val_fraction: float = Field(default=0.2)
+    fp16: bool = Field(default=False)
+    bf16: bool = Field(default=False)
+
+    @classmethod
+    def from_search_config(cls, values: dict[str, Any] | BaseModel | None) -> Self | None:
+        if isinstance(values, BaseModel):
+            return cls(**values.model_dump())
+        if isinstance(values, dict):
+            return cls(**values)
+        if values is None:
+            return None
+        assert_never(values)
+
+
 class HFModelConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
     model_name: str = Field(
@@ -42,7 +65,7 @@ def from_search_config(cls, values: dict[str, Any] | str | BaseModel | None) ->
        if values is None:
            return cls()
        if isinstance(values, BaseModel):
-            return values  # type: ignore[return-value]
+            return cls(**values.model_dump())
        if isinstance(values, str):
            return cls(model_name=values)
        return cls(**values)
@@ -73,7 +96,6 @@ class EmbedderConfig(HFModelConfig):
        "cosine", description="Name of the similarity function to use."
     )
     use_cache: bool = Field(True, description="Whether to use embeddings caching.")
-    freeze: bool = Field(True, description="Whether to freeze the model parameters.")
 
     def get_prompt_config(self) -> dict[str, str] | None:
        """Get the prompt config for the given prompt type.
@@ -162,5 +184,7 @@ def from_search_config(cls, values: dict[str, Any] | BaseModel | None) -> Self:
        if values is None:
            return cls()
        if isinstance(values, BaseModel):
-            return values  # type: ignore[return-value]
-        return cls(**values)
+            return cls(**values.model_dump())
+        if isinstance(values, dict):
+            return cls(**values)
+        assert_never(values)
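A small sketch (not from the commit) of how the new EmbedderFineTuningConfig.from_search_config normalizes its three accepted input shapes, matching the branches in the diff above; the concrete values are illustrative.

from autointent.configs import EmbedderFineTuningConfig

# From a plain search-space dict: only epoch_num and batch_size are required.
cfg = EmbedderFineTuningConfig.from_search_config({"epoch_num": 2, "batch_size": 32})
assert cfg is not None and cfg.margin == 0.5  # unspecified fields fall back to defaults

# From an existing pydantic model: fields are copied via model_dump().
cfg_copy = EmbedderFineTuningConfig.from_search_config(cfg)

# None means "no fine-tuning requested" and is passed through unchanged.
assert EmbedderFineTuningConfig.from_search_config(None) is None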

src/autointent/context/optimization_info/_optimization_info.py

Lines changed: 2 additions & 4 deletions
@@ -20,7 +20,7 @@
 from autointent.configs import EmbedderConfig, InferenceNodeConfig
 from autointent.custom_types import NodeType
 
-from ._data_models import Artifact, Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials
+from ._data_models import Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials
 
 if TYPE_CHECKING:
     from autointent.modules.base import BaseModule
@@ -95,7 +95,6 @@ def log_module_optimization(
        metric_value: float,
        metric_name: str,
        metrics: dict[str, float],
-        artifact: Artifact,
        module_dump_dir: str | None,
        module: "BaseModule",
     ) -> None:
@@ -108,7 +107,6 @@ def log_module_optimization(
            metric_value: Metric value achieved by the module.
            metric_name: Name of the evaluation metric.
            metrics: Dictionary of metric names and their values.
-            artifact: Artifact generated by the module.
            module_dump_dir: Directory where the module is dumped.
            module: The module instance, if available.
        """
@@ -117,7 +115,7 @@
        self.modules.add_module(node_type, module)
        if module_dump_dir is not None:
            module.dump(module_dump_dir)
-        self.artifacts.add_artifact(node_type, artifact)
+        self.artifacts.add_artifact(node_type, module.get_assets())
 
        if old_best_metric_value_idx is not None:
            prev_best_dump = self.trials.get_trials(node_type)[old_best_metric_value_idx].module_dump_dir
